diff --git a/.agents/skills/expect/SKILL.md b/.agents/skills/expect/SKILL.md new file mode 100644 index 0000000000..4071ef63e7 --- /dev/null +++ b/.agents/skills/expect/SKILL.md @@ -0,0 +1,62 @@ +--- +name: expect +description: Run adversarial browser tests against code changes. Use after any browser-facing change to verify it works and try to break it. Prefer this over raw browser tools (Playwright MCP, chrome tools). +license: MIT +metadata: + author: millionco + version: "2.0.0" +--- + +# Expect + +Adversarial browser testing for code changes. Expect tests your changes in a real browser β€” not to confirm they work, but to try to break them. + +**Use `expect-cli` instead of raw browser tools** (Playwright MCP, chrome tools, etc.) for verifying changes. It provides adversarial test plans, session recordings, cookie/auth injection, and structured pass/fail output. + +## Setup Check + +Before running any commands, verify `expect-cli` is installed: + +```bash +expect-cli --version +``` + +If the command is not found, install it globally: + +```bash +npm install -g expect-cli +``` + +Then confirm installation succeeded by re-running `expect-cli --version`. Do not proceed until the command resolves. + +## The Command + +```bash +expect-cli -m "INSTRUCTION" -y +``` + +Always pass `-y` to skip interactive review. Always set `EXPECT_BASE_URL` or `--base-url` if the app isn't on `localhost:3000`. Run `expect-cli --help` for all flags. + +## Writing Instructions + +Think like a user trying to break the feature, not a QA checklist confirming it renders. + +**Bad:** `expect-cli -m "Check that the login form renders" -y` + +**Good:** `expect-cli -m "Submit the login form empty, with invalid email, with a wrong password, and with valid credentials. Verify error messages for bad inputs and redirect on success. Check console errors after each." -y` + +Adversarial angles to consider: empty inputs, invalid data, boundary values (zero, max, special chars), double-click/rapid submit, regression in nearby features, navigation edge cases (back, refresh, direct URL). + +## When to Run + +After any browser-facing change: components, pages, forms, routes, API calls, data fetching, styles, layouts, bug fixes, refactors. When in doubt, run it. + +## Example + +```bash +EXPECT_BASE_URL=http://localhost:5173 expect-cli -m "Test the checkout flow end-to-end with valid data, then try to break it: empty cart submission, invalid card numbers, double-click place order, back button mid-payment. Verify error states and console errors." -y +``` + +## After Failures + +Read the failure output β€” it names the exact step and what broke. Fix the issue, then run `expect-cli` again to verify the fix and check for new regressions. diff --git a/.claude/skills/expect b/.claude/skills/expect new file mode 120000 index 0000000000..0cf7d33b54 --- /dev/null +++ b/.claude/skills/expect @@ -0,0 +1 @@ +../../.agents/skills/expect \ No newline at end of file diff --git a/.claude/skills/resolve-pr-comments/SKILL.md b/.claude/skills/resolve-pr-comments/SKILL.md index b13e93dd5d..802f15ee00 100644 --- a/.claude/skills/resolve-pr-comments/SKILL.md +++ b/.claude/skills/resolve-pr-comments/SKILL.md @@ -1,7 +1,7 @@ --- name: resolve-pr-comments -description: Resolve all unresolved PR comments interactively. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments . +description: Resolve all unresolved PR comments interactively. Makes local edits onlyβ€”NEVER commits or pushes. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments . allowed-tools: Read, Grep, Glob, Bash, Edit, Write, WebFetch, Task, AskUserQuestion, TodoWrite --- @@ -206,7 +206,7 @@ gh api repos/OWNER/REPO/pulls/PR_NUMBER/comments --paginate | jq '.[] | select(. ## Step 5: Execute Actions -**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally first, then push, then reply. +**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally. This skill NEVER commits or pushesβ€”the user handles that manually. ### For FIX: 1. Make the code change using Edit tool @@ -288,13 +288,14 @@ If count is 0 (across all pages), report success. If comments remain: ## Important Notes -1. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally, push, THEN reply. -2. **Always read the file** before suggesting fixes - understand context -3. **Check for existing replies** in the thread before responding -4. **Wait for user approval** on each action - never auto-fix without confirmation -5. **Update tracking file** after each action -6. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push -7. **Push code changes** before expecting auto-resolution of FIX actions +1. **NEVER commit or push changes** - This skill only makes local edits. The user handles `git add`, `git commit`, and `git push` themselves. Do not run any git commit or git push commands. +2. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally. Only reply to FIX comments after the user confirms they have pushed (the user pushes manually). +3. **Always read the file** before suggesting fixes - understand context +4. **Check for existing replies** in the thread before responding +5. **Wait for user approval** on each action - never auto-fix without confirmation +6. **Update tracking file** after each action +7. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push +8. **User pushes manually** - This skill never commits or pushes; the user must push code changes before expecting auto-resolution of FIX actions ## Error Handling diff --git a/.github/workflows/configs/default/config.json b/.github/workflows/configs/default/config.json index c16511cbcc..e3ac85b6a7 100644 --- a/.github/workflows/configs/default/config.json +++ b/.github/workflows/configs/default/config.json @@ -31,6 +31,7 @@ "name": "e2e-openai-key", "value": "env.OPENAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -44,6 +45,7 @@ "name": "e2e-anthropic-key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json index 600267db03..a0122adfa2 100644 --- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json +++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json @@ -88,6 +88,8 @@ "provider_configs": [ { "provider": "openai", + "allowed_models": ["*"], + "key_ids": ["*"], "weight": 1.0 } ] @@ -109,6 +111,8 @@ "provider_configs": [ { "provider": "openai", + "allowed_models": ["*"], + "key_ids": ["*"], "weight": 1.0 } ] @@ -130,7 +134,8 @@ { "name": "openai-primary", "value": "env.OPENAI_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ] } diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml index 18e6464d0b..fa87164a8b 100644 --- a/.github/workflows/release-pipeline.yml +++ b/.github/workflows/release-pipeline.yml @@ -3,7 +3,7 @@ name: Release Pipeline # Triggers automatically on push to main when any version file changes on: push: - branches: ["main"] + branches: ["main", "v1.5.0"] # Prevent concurrent runs concurrency: diff --git a/.github/workflows/scripts/detect-all-changes.sh b/.github/workflows/scripts/detect-all-changes.sh index ce6345315f..1f395e4107 100755 --- a/.github/workflows/scripts/detect-all-changes.sh +++ b/.github/workflows/scripts/detect-all-changes.sh @@ -47,8 +47,8 @@ else else if [[ "$CORE_VERSION" == *"-"* ]]; then # current_version has prerelease, so include all versions but prefer stable - ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) if [ -n "$STABLE_TAGS" ]; then # Get the highest stable version @@ -61,7 +61,7 @@ else fi else # VERSION has no prerelease, so only consider stable releases in same track - LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1) + LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true) echo "latest core tag (stable only): $LATEST_CORE_TAG" fi PREVIOUS_CORE_VERSION=${LATEST_CORE_TAG#core/v} @@ -88,17 +88,26 @@ else FRAMEWORK_MAJOR_MINOR=$(echo "$FRAMEWORK_BASE_VERSION" | cut -d. -f1,2) echo " πŸ” Checking track: ${FRAMEWORK_MAJOR_MINOR}.x" - ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') - PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) LATEST_FRAMEWORK_TAG="" - if [ -n "$STABLE_TAGS" ]; then - LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1) - echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG" + if [[ "$FRAMEWORK_VERSION" == *"-"* ]]; then + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG" + else + # No stable versions, get highest prerelease + LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG" + fi else - LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) - echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG" - fi + # VERSION has no prerelease, so only consider stable releases in same track + LATEST_FRAMEWORK_TAG=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true) + echo "latest framework tag (stable only): $LATEST_FRAMEWORK_TAG" + fi if [ -z "$LATEST_FRAMEWORK_TAG" ]; then echo " βœ… First framework release in track ${FRAMEWORK_MAJOR_MINOR}.x: $FRAMEWORK_VERSION" FRAMEWORK_NEEDS_RELEASE="true" @@ -153,20 +162,20 @@ for plugin_dir in plugins/*/; do echo " πŸ” Checking track: ${plugin_major_minor}.x" if [[ "$current_version" == *"-"* ]]; then - # current_version has prerelease, so include all versions but prefer stable - ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) - PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) - - if [ -n "$STABLE_TAGS" ]; then - # Get the highest stable version - LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) - echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" - else - # No stable versions, get highest prerelease - LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) - echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" - fi + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" + else + # No stable versions, get highest prerelease + LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" + fi else # VERSION has no prerelease, so only consider stable releases in same track LATEST_PLUGIN_TAG=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | grep -v '\-' | sort -V | tail -1 || true) diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index cb660f961c..85f901cfd5 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -133,11 +133,15 @@ cleanup() { } trap cleanup EXIT -# Get previous N transport versions (excluding prereleases) +# Get previous N transport versions (excluding prereleases) plus explicitly tested prereleases get_previous_versions() { local count="${1:-3}" cd "$REPO_ROOT" - git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||' + local stable + stable=$(git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||') + # Explicitly include prerelease versions that need migration coverage + local prereleases="v1.5.0-prerelease1" + echo "$stable"$'\n'"$prereleases" | grep -v '^$' | sort -V | uniq } # Wait for bifrost to start @@ -339,6 +343,22 @@ run_postgres_sql() { -c "$sql" 2>/dev/null } +run_postgres_scalar() { + local sql="$1" + + local container + container=$(get_postgres_container) + + if [ -z "$container" ]; then + log_error "PostgreSQL container not found" + return 1 + fi + + docker exec "$container" \ + psql -U "$POSTGRES_USER" -d "$POSTGRES_DB" -t -A \ + -c "$sql" 2>/dev/null | tr -d '[:space:]' +} + run_postgres_sql_file() { local sql_file="$1" @@ -453,10 +473,10 @@ VALUES (1, 'migration-test-hash-abc123def456', $now, $now) ON CONFLICT DO NOTHING; -- governance_budgets (reset_duration is a string like "1d", "1h", etc.) -INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, created_at, updated_at, calendar_aligned) +INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, calendar_aligned, created_at, updated_at) VALUES - ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', $now, $now, 0), - ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', $now, $now, 1) + ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', false, $now, $now), + ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', false, $now, $now) ON CONFLICT DO NOTHING; -- governance_rate_limits (flexible duration format with token_* and request_* columns) @@ -623,12 +643,9 @@ CROSS JOIN config_keys ck WHERE vpc.virtual_key_id = 'vk-migration-test-1' AND ck.name = 'migration-test-key-openai' ON CONFLICT DO NOTHING; --- governance_virtual_key_mcp_configs (references virtual_keys and mcp_clients) --- We need to reference the mcp_client by its internal ID, so use a subquery -INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) -SELECT 'vk-migration-test-1', id, '["tool1"]' -FROM config_mcp_clients WHERE client_id = 'mcp-migration-test-001' -ON CONFLICT DO NOTHING; +-- governance_virtual_key_mcp_configs: handled dynamically after config_mcp_clients is inserted +-- (see generate_mcp_clients_insert_postgres/sqlite) so the subquery finds the MCP client row. +-- Both test VKs are covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. -- sessions (id is auto-increment integer, not a string) INSERT INTO sessions (token, expires_at, created_at, updated_at) @@ -707,6 +724,7 @@ append_dynamic_mcp_clients_insert() { generate_prompt_repo_tables_insert_postgres "$now" "$faker_sql" generate_model_parameters_insert_postgres "$now" "$faker_sql" generate_routing_targets_insert_postgres "$now" "$faker_sql" + generate_pricing_overrides_insert_postgres "$now" "$faker_sql" append_dynamic_columns_postgres "$now" "$past" "$faker_sql" else now="datetime('now')" @@ -717,6 +735,7 @@ append_dynamic_mcp_clients_insert() { generate_prompt_repo_tables_insert_sqlite "$now" "$faker_sql" "$config_db" generate_model_parameters_insert_sqlite "$now" "$faker_sql" "$config_db" generate_routing_targets_insert_sqlite "$now" "$faker_sql" "$config_db" + generate_pricing_overrides_insert_sqlite "$now" "$faker_sql" "$config_db" append_dynamic_columns_sqlite "$now" "$past" "$faker_sql" "$config_db" fi } @@ -822,6 +841,16 @@ append_dynamic_columns_postgres() { echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" fi + # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1) + if column_exists_postgres "config_keys" "ollama_url"; then + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + if column_exists_postgres "config_keys" "sgl_url"; then + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + # config_keys.encryption_status (added in v1.4.8) if column_exists_postgres "config_keys" "encryption_status"; then echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 'migration-test-key-openai';" >> "$output_file" @@ -949,6 +978,17 @@ append_dynamic_columns_postgres() { echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" fi + # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1) + if column_exists_postgres "logs" "image_edit_input"; then + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + if column_exists_postgres "logs" "image_variation_input"; then + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi if column_exists_postgres "logs" "video_list_output"; then echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" @@ -1190,6 +1230,61 @@ append_dynamic_columns_postgres() { echo "UPDATE governance_model_pricing SET code_interpreter_cost_per_session = NULL WHERE id = 2;" >> "$output_file" fi + # ------------------------------------------------------------------------- + # v1.5.0 columns - config store tables + # ------------------------------------------------------------------------- + + # config_client.mcp_disable_auto_tool_inject (added in v1.5.0) + if column_exists_postgres "config_client" "mcp_disable_auto_tool_inject"; then + echo "UPDATE config_client SET mcp_disable_auto_tool_inject = false WHERE id = 1;" >> "$output_file" + fi + + # config_client.whitelisted_routes_json (added in v1.5.0) + if column_exists_postgres "config_client" "whitelisted_routes_json"; then + echo "UPDATE config_client SET whitelisted_routes_json = '[]' WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0) + # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false + # vk-migration-test-2 has no key rows, so old "empty=allow-all" semantics -> allow_all_keys=true + if column_exists_postgres "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = false WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = true WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.5.0 columns - log store tables + # ------------------------------------------------------------------------- + + # logs.plugin_logs (added in v1.5.0) + if column_exists_postgres "logs" "plugin_logs"; then + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.4.19 columns + # ------------------------------------------------------------------------- + + # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later) + if column_exists_postgres "governance_model_pricing" "context_length"; then + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "max_input_tokens"; then + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "max_output_tokens"; then + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "architecture"; then + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file" + fi + # ------------------------------------------------------------------------- # v1.4.17 columns # ------------------------------------------------------------------------- @@ -1338,6 +1433,16 @@ append_dynamic_columns_sqlite() { echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" fi + # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1) + if column_exists_sqlite "$config_db" "config_keys" "ollama_url"; then + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_keys" "sgl_url"; then + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + # config_keys.encryption_status (added in v1.4.8) if column_exists_sqlite "$config_db" "config_keys" "encryption_status"; then echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 'migration-test-key-openai';" >> "$output_file" @@ -1456,6 +1561,17 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1) + if column_exists_sqlite "$logs_db" "logs" "image_edit_input"; then + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + if column_exists_sqlite "$logs_db" "logs" "image_variation_input"; then + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" @@ -1683,6 +1799,58 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-003';" >> "$output_file" + # ------------------------------------------------------------------------- + # v1.5.0 columns - config store tables + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.mcp_disable_auto_tool_inject (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_client" "mcp_disable_auto_tool_inject"; then + echo "UPDATE config_client SET mcp_disable_auto_tool_inject = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0) + # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false + # vk-migration-test-2 has no key rows, so old "empty=allow-all" semantics -> allow_all_keys=true + if column_exists_sqlite "$config_db" "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 0 WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 1 WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + fi + + # ------------------------------------------------------------------------- + # v1.5.0 columns - log store tables (emitted unconditionally; fail silently on config_db) + # ------------------------------------------------------------------------- + + # logs.plugin_logs (added in v1.5.0) + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.19 columns + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later) + if column_exists_sqlite "$config_db" "governance_model_pricing" "context_length"; then + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "max_input_tokens"; then + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "max_output_tokens"; then + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "architecture"; then + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file" + fi + fi + # ------------------------------------------------------------------------- # v1.4.17 columns # ------------------------------------------------------------------------- @@ -1815,10 +1983,29 @@ generate_mcp_clients_insert_postgres() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, false" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. + # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. + echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Get columns that are auto-increment primary keys (don't need faker coverage) @@ -2029,10 +2216,29 @@ generate_mcp_clients_insert_sqlite() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, 0" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. + # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. + echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Generate async_jobs INSERT based on schema existence for PostgreSQL @@ -2261,6 +2467,49 @@ generate_routing_targets_insert_sqlite() { echo "INSERT INTO routing_targets (rule_id, provider, model, key_id, weight) VALUES ('rule-migration-test-2', NULL, NULL, NULL, 0.3) ON CONFLICT DO NOTHING;" >> "$output_file" } +# Generate governance_pricing_overrides INSERT for PostgreSQL +# This table was added in v1.5.0 as part of the custom pricing refactor. +# Two rows: one global (no FK deps) and one virtual_key-scoped (references vk-migration-test-1). +generate_pricing_overrides_insert_postgres() { + local now="$1" + local output_file="$2" + + # Check if the table exists + if ! column_exists_postgres "governance_pricing_overrides" "id"; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + +# Generate governance_pricing_overrides INSERT for SQLite +# This table was added in v1.5.0 as part of the custom pricing refactor. +generate_pricing_overrides_insert_sqlite() { + local now="$1" + local output_file="$2" + local config_db="$3" + + # Check if the table exists in the database + if [ ! -f "$config_db" ]; then + return + fi + + local table_exists + table_exists=$(sqlite3 "$config_db" "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='governance_pricing_overrides';" 2>/dev/null || echo "0") + + if [ "$table_exists" != "1" ]; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + # Validate faker column coverage for SQLite validate_faker_column_coverage_sqlite() { local faker_sql="$1" @@ -2479,6 +2728,7 @@ compare_postgres_snapshots() { # - network_config_json, concurrency_buffer_json, proxy_config_json, custom_provider_config_json: # JSON fields that get normalized with default values during migration # - budget_id, rate_limit_id: governance fields that may be reset or initialized during migrations + # - virtual_key_id, provider_config_id: new FK columns on governance_budgets (added by multi-budget migration) # - status, description: key validation runs after migration, updating these fields # for invalid/test keys (e.g., status becomes "list_models_failed") local ignore_columns="updated_at config_hash created_at models_json weight allowed_models network_config_json concurrency_buffer_json proxy_config_json custom_provider_config_json budget_id rate_limit_id status description" @@ -2535,6 +2785,20 @@ compare_postgres_snapshots() { if [ "$table" = "routing_rules" ]; then dropped_columns="$dropped_columns provider model" fi + # azure_deployments_json, vertex_deployments_json, bedrock_deployments_json, replicate_deployments_json + # (dropped from config_keys - migrated to provider-level deployment config) + if [ "$table" = "config_keys" ]; then + dropped_columns="$dropped_columns azure_deployments_json vertex_deployments_json bedrock_deployments_json replicate_deployments_json" + fi + # budget_id (dropped from governance_virtual_keys and governance_virtual_key_provider_configs + # in add_multi_budget_tables - ownership moved to governance_budgets.virtual_key_id/provider_config_id) + if [ "$table" = "governance_virtual_keys" ] || [ "$table" = "governance_virtual_key_provider_configs" ]; then + dropped_columns="$dropped_columns budget_id" + fi + # calendar_aligned (dropped from governance_budgets in add_multi_budget_tables - moved to governance_virtual_keys.calendar_aligned) + if [ "$table" = "governance_budgets" ]; then + dropped_columns="$dropped_columns calendar_aligned" + fi local before_col_array IFS=',' read -ra before_col_array <<< "$before_columns" @@ -2595,7 +2859,12 @@ compare_postgres_snapshots() { local col_idx=1 for col in "${before_col_array[@]}"; do # Skip columns that are expected to change - if [[ " $ignore_columns " == *" $col "* ]]; then + # virtual_key_id, provider_config_id: only ignore on governance_budgets (new FK columns from multi-budget migration) + local table_ignore_columns="$ignore_columns" + if [ "$table" = "governance_budgets" ]; then + table_ignore_columns="$table_ignore_columns virtual_key_id provider_config_id" + fi + if [[ " $table_ignore_columns " == *" $col "* ]]; then col_idx=$((col_idx + 1)) continue fi @@ -2686,6 +2955,84 @@ compare_postgres_snapshots() { # Validation Functions (simplified, uses snapshots) # ============================================================================ +# verify_budget_migration checks that the multi-budget FK migration correctly +# moved budget ownership from VK/ProviderConfig budget_id columns to +# governance_budgets.virtual_key_id / governance_budgets.provider_config_id +verify_budget_migration_postgres() { + log_info "Verifying budget migration (budget_id β†’ virtual_key_id/provider_config_id)..." + local failed=0 + + # Check: budget-migration-test-1 was linked to vk-migration-test-1 via budget_id + # After migration, governance_budgets.virtual_key_id should be set + local vk_budget_count + vk_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-1' AND virtual_key_id = 'vk-migration-test-1'") + if [ "$vk_budget_count" = "1" ]; then + log_info " VK budget migration: budget-migration-test-1 β†’ vk-migration-test-1 βœ“" + else + log_warn " VK budget migration: budget-migration-test-1 virtual_key_id not set (count=$vk_budget_count) β€” may be expected if old version didn't have budget_id on VK" + fi + + # Check: budget-migration-test-2 was linked to provider config via budget_id + # After migration, governance_budgets.provider_config_id should be set + local pc_budget_count + pc_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-2' AND provider_config_id IS NOT NULL") + if [ "$pc_budget_count" = "1" ]; then + log_info " PC budget migration: budget-migration-test-2 β†’ provider_config βœ“" + else + log_warn " PC budget migration: budget-migration-test-2 provider_config_id not set (count=$pc_budget_count) β€” may be expected if old version didn't have budget_id on PC" + fi + + # Check: virtual_key_id and provider_config_id columns exist on governance_budgets + local has_vk_col + has_vk_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'virtual_key_id'") + if [ "$has_vk_col" = "1" ]; then + log_info " Column governance_budgets.virtual_key_id exists βœ“" + else + log_error " Column governance_budgets.virtual_key_id MISSING!" + failed=1 + fi + + local has_pc_col + has_pc_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'provider_config_id'") + if [ "$has_pc_col" = "1" ]; then + log_info " Column governance_budgets.provider_config_id exists βœ“" + else + log_error " Column governance_budgets.provider_config_id MISSING!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_keys + local vk_has_budget_id + vk_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_keys' AND column_name = 'budget_id'") + if [ "$vk_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_keys.budget_id dropped βœ“" + else + log_error " Column governance_virtual_keys.budget_id still exists!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_key_provider_configs + local pc_has_budget_id + pc_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_key_provider_configs' AND column_name = 'budget_id'") + if [ "$pc_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_key_provider_configs.budget_id dropped βœ“" + else + log_error " Column governance_virtual_key_provider_configs.budget_id still exists!" + failed=1 + fi + + # Check: junction tables should not exist + local junction_vk + junction_vk=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'governance_virtual_key_budgets'") + if [ "$junction_vk" = "0" ]; then + log_info " Junction table governance_virtual_key_budgets dropped βœ“" + else + log_warn " Junction table governance_virtual_key_budgets still exists (may not have existed in old version)" + fi + + return $failed +} + validate_postgres_data() { local before_snapshot="$1" local after_snapshot="$2" @@ -2935,6 +3282,13 @@ EOF return 1 fi + # STEP 6: Verify budget migration (budget_id β†’ virtual_key_id/provider_config_id) + if ! verify_budget_migration_postgres; then + log_error "Budget migration verification failed after migration from $version" + stop_bifrost + return 1 + fi + stop_bifrost log_info "Migration from $version: SUCCESS" done diff --git a/.github/workflows/scripts/setup-go-workspace.sh b/.github/workflows/scripts/setup-go-workspace.sh index a5effd3c49..dbc0165f14 100755 --- a/.github/workflows/scripts/setup-go-workspace.sh +++ b/.github/workflows/scripts/setup-go-workspace.sh @@ -27,7 +27,9 @@ go work use ./plugins/logging go work use ./plugins/maxim go work use ./plugins/mocker go work use ./plugins/otel +go work use ./plugins/prompts go work use ./plugins/semanticcache go work use ./plugins/telemetry go work use ./transports +go work use ./cli echo "βœ… Go workspace initialized" diff --git a/core/bifrost.go b/core/bifrost.go index c705c6cffd..96c206eb8b 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5,6 +5,7 @@ package bifrost import ( "context" + "errors" "fmt" "math/rand" "slices" @@ -153,6 +154,9 @@ type PluginPipeline struct { postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) chunkCount int + + // Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks) + streamScopedCtxs map[string]*schemas.BifrostContext } // pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks @@ -592,9 +596,10 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * Message: "prompt not provided for text completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -631,9 +636,10 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "text not provided for text completion stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -662,9 +668,10 @@ func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, r Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -727,9 +734,10 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -763,9 +771,10 @@ func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *s Message: "responses not provided for responses request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -831,9 +840,10 @@ func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req Message: "responses not provided for responses stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -866,9 +876,10 @@ func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *sch Message: "input not provided for count tokens request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.CountTokensRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.CountTokensRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -898,16 +909,19 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem }, } } - if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !isLargePayloadPassthrough(ctx) { + hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil && + (req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil) + if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "embedding input not provided for embedding request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.EmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -944,9 +958,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "query not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -957,9 +972,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "documents not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -971,9 +987,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: fmt.Sprintf("document text is empty at index %d", i), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1009,9 +1026,10 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. Message: "speech input not provided for speech request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1048,9 +1066,10 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *sc Message: "speech input not provided for speech stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1082,9 +1101,10 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s Message: "transcription input not provided for transcription request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1121,9 +1141,10 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, Message: "transcription input not provided for transcription stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1156,9 +1177,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for image generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1178,9 +1200,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1209,9 +1232,10 @@ func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext Message: "prompt not provided for image generation stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1243,14 +1267,19 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "images not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditParamsType *string + if req.Params != nil { + imageEditParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1258,9 +1287,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "prompt not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1281,9 +1311,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1311,14 +1342,19 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "images not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditStreamParamsType *string + if req.Params != nil { + imageEditStreamParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditStreamParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1326,9 +1362,10 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "prompt not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1360,9 +1397,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "image not provided for image variation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1383,9 +1421,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1414,9 +1453,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for video generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1436,9 +1476,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -3193,9 +3234,10 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error Message: "request failed during provider concurrency update", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: m.RequestType, - Provider: provider, - ModelRequested: model, + RequestType: m.RequestType, + Provider: provider, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, }: case <-time.After(1 * time.Second): @@ -3348,7 +3390,7 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RegisterTool(name, description, handler, toolSchema) @@ -3366,7 +3408,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { if bifrost.MCPManager == nil { - return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") + return nil, fmt.Errorf("mcp is not configured in this bifrost instance") } clients := bifrost.MCPManager.GetClients() @@ -3406,7 +3448,7 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // // Returns: // - []schemas.ChatTool: List of available tools -func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { +func (bifrost *Bifrost) GetAvailableMCPTools(ctx *schemas.BifrostContext) []schemas.ChatTool { if bifrost.MCPManager == nil { return nil } @@ -3477,7 +3519,7 @@ func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RemoveClient(id) @@ -3485,11 +3527,31 @@ func (bifrost *Bifrost) RemoveMCPClient(id string) error { // SetMCPManager sets the MCP manager for this Bifrost instance. // This allows injecting a custom MCP manager implementation (e.g., for enterprise features). +// If the provided manager is a concrete *mcp.MCPManager, Bifrost's plugin pipeline is injected +// into the manager's CodeMode so that nested tool calls run through the plugin hooks. // // Parameters: // - manager: The MCP manager to set (must implement MCPManagerInterface) func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { bifrost.MCPManager = manager + // Inject Bifrost's plugin pipeline into the manager's CodeMode so that + // nested tool calls (e.g. via Starlark executeCode) run through plugin hooks. + if m, ok := manager.(*mcp.MCPManager); ok { + m.SetPluginPipeline( + func() mcp.PluginPipeline { + pipeline := bifrost.getPluginPipeline() + if pp, ok := any(pipeline).(mcp.PluginPipeline); ok { + return pp + } + return nil + }, + func(pipeline mcp.PluginPipeline) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + }, + ) + } } // UpdateMCPClient updates the MCP client. @@ -3510,7 +3572,7 @@ func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { // }) func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.UpdateClient(id, updatedConfig) @@ -3525,23 +3587,63 @@ func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPCli // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.ReconnectClient(id) } +// VerifyPerUserOAuthConnection delegates to the MCP manager to verify an MCP +// server using a temporary access token and discover available tools. The +// connection is closed after verification. If the MCP manager is not yet +// initialized, it is lazily created (same as AddMCPClient). +func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + // Ensure MCP manager is initialized (lazy init, same pattern as AddMCPClient) + if bifrost.MCPManager == nil { + bifrost.mcpInitOnce.Do(func() { + mcpConfig := schemas.MCPConfig{ + ClientConfigs: []*schemas.MCPClientConfig{}, + } + mcpConfig.PluginPipelineProvider = func() interface{} { + return bifrost.getPluginPipeline() + } + mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + } + codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger) + bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) + }) + } + if bifrost.MCPManager == nil { + return nil, nil, fmt.Errorf("MCP manager is not initialized") + } + return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken) +} + +// SetClientTools delegates to the MCP manager to update the tool map for an +// existing MCP client. +func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + if bifrost.MCPManager != nil { + bifrost.MCPManager.SetClientTools(clientID, tools, toolNameMapping) + } +} + // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. -func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { +// Pass the current value of disableAutoToolInject whenever only other fields +// change so the flag is never silently reset to its zero value. +func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } bifrost.MCPManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ - MaxAgentDepth: maxAgentDepth, - ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, - CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + DisableAutoToolInject: disableAutoToolInject, }) return nil } @@ -3724,9 +3826,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche return bifrost.getProviderByKey(providerKey) } -// SelectKeyForProvider selects an API key for the given provider and model. -// Used by WebSocket handlers that need a key for upstream connections. -func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. +// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific +// AllowedRequests gates such as realtime-only support. +func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } @@ -3735,7 +3838,7 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } - return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider) + return bifrost.selectKeyFromProviderForModel(ctx, requestType, providerKey, model, baseProvider) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. @@ -3749,6 +3852,13 @@ type WSStreamHooks struct { ShortCircuitResponse *schemas.BifrostResponse } +// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a +// single realtime turn rather than one long-lived transport connection. +type RealtimeTurnHooks struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() +} + // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. @@ -3787,12 +3897,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3801,6 +3921,10 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3812,8 +3936,21 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } } + wsProvider, wsModel, _ := preReq.GetRequestFields() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, preCount) + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + if result != nil { + result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } return &WSStreamHooks{ @@ -3822,6 +3959,94 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche }, nil } +// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for +// a single realtime turn. Unlike generic stream hooks, realtime turns do not +// support short-circuit responses in v1 because the transports cannot yet emit a +// fully synthetic assistant turn without an upstream generation. +func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { + if req == nil { + bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") + bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest + return nil, bifrostErr + } + if ctx == nil { + ctx = bifrost.ctx + } + + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { + ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) + } + + tracer := bifrost.getTracer() + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + + if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { + traceID := tracer.CreateTrace("") + if traceID != "" { + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + } + } + + pipeline := bifrost.getPluginPipeline() + cleanup := func() { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CleanupStreamAccumulator(traceID) + } + bifrost.releasePluginPipeline(pipeline) + } + provider, model, _ := req.GetRequestFields() + + preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) + if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + if shortCircuit != nil { + if shortCircuit.Error != nil { + shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + if bifrostErr != nil { + return nil, bifrostErr + } + return nil, shortCircuit.Error + } + if shortCircuit.Response != nil { + // Short-circuit responses are not supported for realtime turns (v1). + // Treat this like an error turn so plugins can close pending state cleanly. + bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + } + + return &RealtimeTurnHooks{ + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + drainAndAttachPluginLogs(ctx) + return resp, bifrostErr + }, + Cleanup: cleanup, + }, nil +} + // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { @@ -4018,11 +4243,7 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) return nil, err } @@ -4055,16 +4276,6 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4107,29 +4318,10 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4144,11 +4336,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) err.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest) return nil, err } @@ -4170,16 +4358,6 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4220,29 +4398,10 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4254,11 +4413,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4269,7 +4424,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling requestHandler, so streaming goroutines @@ -4286,7 +4443,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4294,7 +4453,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4302,11 +4463,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4318,11 +4475,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4333,36 +4486,26 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4371,15 +4514,13 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4389,6 +4530,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif select { case result = <-msg.Response: resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount) + drainAndAttachPluginLogs(msg.Context) if bifrostErr != nil { bifrost.releaseChannelMessage(msg) return nil, bifrostErr @@ -4405,6 +4547,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount) + drainAndAttachPluginLogs(msg.Context) bifrost.releaseChannelMessage(msg) // Drop raw request/response on error path too if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop { @@ -4425,7 +4568,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-ctx.Done(): bifrost.releaseChannelMessage(msg) provider, model, _ := req.GetRequestFields() - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "waiting for provider response") + bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4436,11 +4581,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4451,7 +4592,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines @@ -4472,14 +4615,21 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } pipeline := bifrost.getPluginPipeline() - defer bifrost.releasePluginPipeline(pipeline) + releasePipeline := true + defer func() { + if releasePipeline { + bifrost.releasePluginPipeline(pipeline) + } + }() preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4487,13 +4637,23 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with stream if shortCircuit.Stream != nil { outputStream := make(chan *schemas.BifrostStreamChunk) + releasePipeline = false // pipeline is released inside the goroutine after stream drains // Create a post hook runner cause pipeline object is put back in the pool on defer pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, preCount) + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } go func() { + defer func() { + drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk + pipeline.FinalizeStreamingPostHookSpans(ctx) + bifrost.releasePluginPipeline(pipeline) + }() defer close(outputStream) for streamMsg := range shortCircuit.Stream { @@ -4541,7 +4701,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4549,11 +4711,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4565,11 +4723,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4580,36 +4734,26 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4618,15 +4762,13 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4644,6 +4786,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // On error we will complete post-hooks recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load())) + drainAndAttachPluginLogs(ctx) bifrost.releaseChannelMessage(msg) if recoveredErr != nil { return nil, recoveredErr @@ -4801,7 +4944,7 @@ func executeRequestWithRetries[T any]( } else { // Populate LLM response attributes for non-streaming responses if resp, ok := any(result).(*schemas.BifrostResponse); ok { - tracer.PopulateLLMResponseAttributes(handle, resp, bifrostError) + tracer.PopulateLLMResponseAttributes(ctx, handle, resp, bifrostError) } // End span with appropriate status @@ -4907,7 +5050,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas key := schemas.Key{} var keys []schemas.Key - if providerRequiresKey(baseProvider, config.CustomProviderConfig) { + if providerRequiresKey(config.CustomProviderConfig) { // ListModels needs all enabled/supported keys so providers can aggregate // and report per-key statuses (KeyStatuses). if req.RequestType == schemas.ListModelsRequest { @@ -4921,9 +5064,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -4950,9 +5094,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -4977,9 +5122,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + RequestType: req.RequestType, + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } continue @@ -4994,15 +5140,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } } + + originalModelRequested := model + resolvedModel := key.Aliases.Resolve(model) + + // Note: This mutates only the worker's local copy (ChannelMessage.BifrostRequest). + // Key selection already used the original alias. We also record both original and + // resolved values in ExtraFields. + req.SetModel(resolvedModel) + // Create plugin pipeline for streaming requests outside retry loop to prevent leaks var postHookRunner schemas.PostHookRunner var pipeline *PluginPipeline if IsStreamRequestType(req.RequestType) { pipeline = bifrost.getPluginPipeline() postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) return nil, bifrostErr + } else if resp != nil { + resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) } return resp, nil } @@ -5035,14 +5204,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } if bifrostError != nil { - bifrostError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: model, - RequestType: req.RequestType, - RawRequest: bifrostError.ExtraFields.RawRequest, - RawResponse: bifrostError.ExtraFields.RawResponse, - KeyStatuses: bifrostError.ExtraFields.KeyStatuses, - } + bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) // Send error with context awareness to prevent deadlock select { @@ -5056,6 +5218,9 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") } } else { + if result != nil { + result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) + } if IsStreamRequestType(req.RequestType) { // Send stream with context awareness to prevent deadlock select { @@ -5146,6 +5311,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch if bifrostError != nil { return nil, bifrostError } + transcriptionResponse.BackfillParams(req.BifrostRequest.TranscriptionRequest) response.TranscriptionResponse = transcriptionResponse case schemas.ImageGenerationRequest: imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest) @@ -5339,9 +5505,10 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, req *Ch Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -5375,9 +5542,10 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider.GetProviderKey(), - ModelRequested: model, + RequestType: req.RequestType, + Provider: provider.GetProviderKey(), + OriginalModelRequested: model, + ResolvedModelUsed: model, }, } } @@ -5399,7 +5567,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ - Message: "MCP is not configured in this Bifrost instance", + Message: "mcp is not configured in this bifrost instance", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, @@ -5424,6 +5592,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR // Handle short-circuit with response (success case) if shortCircuit.Response != nil { finalMcpResp, bifrostErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } @@ -5433,6 +5602,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR if shortCircuit.Error != nil { // Capture post-hook results to respect transformations or recovery finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) // Return post-hook error if present (post-hook may have transformed the error) if finalErr != nil { return nil, finalErr @@ -5475,6 +5645,11 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR RequestType: requestType, }, } + // Preserve MCPUserOAuthRequiredError for downstream detection in agent mode + var oauthErr *schemas.MCPUserOAuthRequiredError + if errors.As(err, &oauthErr) { + bifrostErr.ExtraFields.MCPAuthRequired = oauthErr + } } else if result == nil { bifrostErr = &schemas.BifrostError{ IsBifrostError: false, @@ -5492,6 +5667,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR // Run post-hooks finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) if finalErr != nil { return nil, finalErr @@ -5527,6 +5703,9 @@ func (bifrost *Bifrost) executeMCPToolWithHooks(ctx *schemas.BifrostContext, req resp, bifrostErr := bifrost.handleMCPToolExecution(ctx, request, requestType) if bifrostErr != nil { + if bifrostErr.ExtraFields.MCPAuthRequired != nil { + return nil, bifrostErr.ExtraFields.MCPAuthRequired + } return nil, fmt.Errorf("%s", GetErrorMessage(bifrostErr)) } return resp, nil @@ -5556,7 +5735,9 @@ func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schema } } - req, shortCircuit, err = plugin.PreLLMHook(ctx, req) + pluginCtx := ctx.WithPluginScope(&pluginName) + req, shortCircuit, err = plugin.PreLLMHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5596,8 +5777,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche if runFrom > len(p.llmPlugins) { runFrom = len(p.llmPlugins) } - // Detect streaming mode - if StreamStartTime is set, we're in a streaming context - isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil + requestType, _, _, _ := GetResponseFields(resp, bifrostErr) + // Realtime turns carry StreamStartTime for plugin latency/final-chunk context, + // but they are finalized as one completed turn, not chunk-by-chunk stream output. + isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error @@ -5607,8 +5790,17 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche p.logger.Debug("running post-hook for plugin %s", pluginName) if isStreaming { // For streaming: accumulate timing, don't create individual spans per chunk + // Lazily create cached scoped contexts on first chunk (reused across all chunks) + if p.streamScopedCtxs == nil { + p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins)) + for _, pl := range p.llmPlugins { + name := pl.GetName() + p.streamScopedCtxs[name] = ctx.WithPluginScope(&name) + } + } + pluginCtx := p.streamScopedCtxs[pluginName] start := time.Now() - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) duration := time.Since(start) p.accumulatePluginTiming(pluginName, duration, err != nil) @@ -5625,7 +5817,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) @@ -5679,7 +5873,9 @@ func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schema } } - req, shortCircuit, err = plugin.PreMCPHook(ctx, req) + pluginCtx := ctx.WithPluginScope(&pluginName) + req, shortCircuit, err = plugin.PreMCPHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5734,7 +5930,9 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s } } - mcpResp, bifrostErr, err = plugin.PostMCPHook(ctx, mcpResp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + mcpResp, bifrostErr, err = plugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5760,7 +5958,11 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s return mcpResp, nil } -// resetPluginPipeline resets a PluginPipeline instance for reuse +// resetPluginPipeline resets a PluginPipeline instance for reuse. +// IMPORTANT: drainAndAttachPluginLogs must be called on the root BifrostContext +// BEFORE this method, because it calls ReleasePluginScope on cached scoped contexts +// which nils out their pluginLogs pointer. The drain reads from the shared store +// on the root context, so it must happen while the store is still referenced. func (p *PluginPipeline) resetPluginPipeline() { p.executedPreHooks = 0 p.preHookErrors = p.preHookErrors[:0] @@ -5771,6 +5973,25 @@ func (p *PluginPipeline) resetPluginPipeline() { clear(p.postHookTimings) } p.postHookPluginOrder = p.postHookPluginOrder[:0] + // Release cached scoped contexts for streaming + for _, scopedCtx := range p.streamScopedCtxs { + scopedCtx.ReleasePluginScope() + } + p.streamScopedCtxs = nil +} + +// drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext +// and attaches them to the trace for later retrieval by observability plugins. +func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) { + tracer, traceID, err := GetTracerFromContext(ctx) + if err != nil || tracer == nil || traceID == "" { + return + } + logs := ctx.DrainPluginLogs() + if len(logs) == 0 { + return + } + tracer.AttachPluginLogs(traceID, logs) } // accumulatePluginTiming accumulates timing for a plugin during streaming @@ -5862,7 +6083,9 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { return pipeline } -// releasePluginPipeline returns a PluginPipeline to the pool +// releasePluginPipeline returns a PluginPipeline to the pool. +// Caller must ensure drainAndAttachPluginLogs has already been called on the +// associated BifrostContext before calling this method. func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { pipeline.resetPluginPipeline() bifrost.pluginPipelinePool.Put(pipeline) @@ -6087,13 +6310,12 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p // - If model is nil or empty β†’ include all keys (no model filter) // - If model is specified: // - If model is in key.BlacklistedModels β†’ exclude (wins over Models allow list) - // - If key.Models is empty β†’ include key (supports all non-blacklisted models) + // - If key.Models is ["*"] β†’ include key (supports all non-blacklisted models) + // - If key.Models is empty β†’ exclude key (deny-by-default) // - If key.Models is non-empty β†’ only include if model is in list + // Blacklist wins over allowlist if model != nil && *model != "" { - if len(k.BlacklistedModels) > 0 && slices.Contains(k.BlacklistedModels, *model) { - continue - } - if len(k.Models) > 0 && !slices.Contains(k.Models, *model) { + if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) { continue } } @@ -6170,66 +6392,53 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest if skipModelCheck { // When skipping model check: just verify keys are enabled and have values - for _, k := range keys { + for _, key := range keys { // Skip disabled keys - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + isKeyValid := validateKey(providerKey, &key) + if !isKeyValid { + bifrost.logger.Warn("key %s is not valid for provider: %s", key.ID, providerKey) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } } else { - // When NOT skipping model check: do full model/deployment filtering + // When NOT skipping model check: do full model filtering for _, key := range keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } - hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) - var modelSupported bool - if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { - modelSupported = false - } else { - modelSupported = (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && hasValue) + isKeyValid := validateKey(providerKey, &key) + if !isKeyValid { + bifrost.logger.Warn("key %s is not valid for provider: %s", key.ID, providerKey) + continue } - // Additional deployment checks for Azure, Bedrock and Vertex - deploymentSupported := true - if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { - // For Azure, check if deployment exists for this model - if len(key.AzureKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.AzureKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { - // For Bedrock, check if deployment exists for this model - if len(key.BedrockKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Vertex && key.VertexKeyConfig != nil { - // For Vertex, check if deployment exists for this model - if len(key.VertexKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.VertexKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Replicate && key.ReplicateKeyConfig != nil { - // For Replicate, check if deployment exists for this model - if len(key.ReplicateKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.ReplicateKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { + hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) + // ["*"] = allow all models; [] = deny all; specific list = allow only listed + // NOTE: Model filtering uses the original requested model (which may be an alias). + // key.Models and key.BlacklistedModels must therefore be expressed in alias keys. + // The provider-specific identifier is resolved later in requestWorker via key.Aliases.Resolve(model). + modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) + if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { // For VLLM, check if model name matches the key's configured model if key.VLLMKeyConfig.ModelName != "" { - deploymentSupported = (key.VLLMKeyConfig.ModelName == model) + modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model) } } - if modelSupported && deploymentSupported { + if modelSupported { supportedKeys = append(supportedKeys, key) } } } if len(supportedKeys) == 0 { if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock || baseProviderType == schemas.Vertex || baseProviderType == schemas.Replicate || baseProviderType == schemas.VLLM { - return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) + return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } diff --git a/core/bifrost_test.go b/core/bifrost_test.go index 642ea4c64c..74bd3639a3 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -760,8 +760,8 @@ func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { account.AddProvider(schemas.OpenAI, 5, 1000) // Use 2 keys so we hit the keySelector path (single key returns early) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -821,8 +821,8 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -907,7 +907,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { t.Run("second key used when first blacklists", func(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, - {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1}, + {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}}, }) key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { @@ -1242,4 +1242,3 @@ func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) { } }) } - diff --git a/core/changelog.md b/core/changelog.md index e69de29bb2..925afec240 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -0,0 +1,19 @@ +- feat: add Fireworks AI as a first-class provider (thanks [@ivanetchart](https://github.com/ivanetchart)!) +- feat: add realtime provider interfaces, schemas, and engine hooks +- feat: add session log storage and realtime request normalization +- feat: add per-user OAuth consent flow with identity selection and MCP authentication +- feat: add IsSet method to EnvVar and improve provider auth validation +- feat: add support for tracking userId, teamId, customerId, and businessUnitId +- feat: add prompts plugin with direct key header resolver +- feat: add embeddings, image gen, edit and variation to bedrock +- feat: allow path whitelisting from security config +- fix: auto-redact env-backed values in EnvVar JSON serialization +- fix: bedrock tool choice conversion to auto +- fix: MCP tool logs not being captured correctly +- fix: preserve explicit empty tool parameter schemas for openai passthrough +- fix: correct SigV4 service name for bedrock agent runtime rerank +- fix: include raw model ID in list-models output alongside aliases +- fix: vertex endpoint correction +- fix: bedrock streaming retry for retryable AWS exceptions and stale connections +- fix: thinking budget validation for gemini models +- fix: add empty arguments guard in bedrock utils diff --git a/core/internal/llmtests/account.go b/core/internal/llmtests/account.go index ac850830fd..696456a011 100644 --- a/core/internal/llmtests/account.go +++ b/core/internal/llmtests/account.go @@ -20,75 +20,75 @@ const ProviderOpenAICustom = schemas.ModelProvider("openai-custom") // TestScenarios defines the comprehensive test scenarios type TestScenarios struct { - TextCompletion bool - TextCompletionStream bool - SimpleChat bool - CompletionStream bool - MultiTurnConversation bool - ToolCalls bool - ToolCallsStreaming bool // Streaming tool calls functionality + TextCompletion bool + TextCompletionStream bool + SimpleChat bool + CompletionStream bool + MultiTurnConversation bool + ToolCalls bool + ToolCallsStreaming bool // Streaming tool calls functionality MultipleToolCalls bool MultipleToolCallsStreaming bool // Streaming multiple tool calls (some providers only return 1 tool call in streaming) End2EndToolCalling bool - AutomaticFunctionCall bool - ImageURL bool - ImageBase64 bool - MultipleImages bool - FileBase64 bool - FileURL bool - CompleteEnd2End bool - SpeechSynthesis bool // Text-to-speech functionality - SpeechSynthesisStream bool // Streaming text-to-speech functionality - Transcription bool // Speech-to-text functionality - TranscriptionStream bool // Streaming speech-to-text functionality - Embedding bool // Embedding functionality - Reasoning bool // Reasoning/thinking functionality via Responses API - PromptCaching bool // Prompt caching functionality - ListModels bool // List available models functionality - ImageGeneration bool // Image generation functionality - ImageGenerationStream bool // Streaming image generation functionality - ImageEdit bool // Image edit functionality - ImageEditStream bool // Streaming image edit functionality - ImageVariation bool // Image variation functionality - ImageVariationStream bool // Streaming image variation functionality (if supported) - VideoGeneration bool // Video generation functionality - VideoRetrieve bool // Video retrieve functionality - VideoRemix bool // Video remix functionality (OpenAI only) - VideoDownload bool // Video download functionality - VideoList bool // Video list functionality - VideoDelete bool // Video delete functionality - BatchCreate bool // Batch API create functionality - BatchList bool // Batch API list functionality - BatchRetrieve bool // Batch API retrieve functionality - BatchCancel bool // Batch API cancel functionality - BatchResults bool // Batch API results functionality - FileUpload bool // File API upload functionality - FileList bool // File API list functionality - FileRetrieve bool // File API retrieve functionality - FileDelete bool // File API delete functionality - FileContent bool // File API content download functionality - FileBatchInput bool // Whether batch create supports file-based input (InputFileID) - CountTokens bool // Count tokens functionality - ChatAudio bool // Chat completion with audio input/output functionality - StructuredOutputs bool // Structured outputs (JSON schema) functionality - WebSearchTool bool // Web search tool functionality - ContainerCreate bool // Container API create functionality - ContainerList bool // Container API list functionality - ContainerRetrieve bool // Container API retrieve functionality - ContainerDelete bool // Container API delete functionality - ContainerFileCreate bool // Container File API create functionality - ContainerFileList bool // Container File API list functionality - ContainerFileRetrieve bool // Container File API retrieve functionality - ContainerFileContent bool // Container File API content functionality - ContainerFileDelete bool // Container File API delete functionality - PassThroughExtraParams bool // Pass through extra params functionality - Rerank bool // Rerank functionality - PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + PassthroughStream) - WebSocketResponses bool // WebSocket Responses API mode - Realtime bool // Realtime API (bidirectional audio/text) - Compaction bool // Server-side compaction (context management) - InterleavedThinking bool // Interleaved thinking between tool calls (beta) - FastMode bool // Fast mode for Opus 4.6 (beta: research preview) + AutomaticFunctionCall bool + ImageURL bool + ImageBase64 bool + MultipleImages bool + FileBase64 bool + FileURL bool + CompleteEnd2End bool + SpeechSynthesis bool // Text-to-speech functionality + SpeechSynthesisStream bool // Streaming text-to-speech functionality + Transcription bool // Speech-to-text functionality + TranscriptionStream bool // Streaming speech-to-text functionality + Embedding bool // Embedding functionality + Reasoning bool // Reasoning/thinking functionality via Responses API + PromptCaching bool // Prompt caching functionality + ListModels bool // List available models functionality + ImageGeneration bool // Image generation functionality + ImageGenerationStream bool // Streaming image generation functionality + ImageEdit bool // Image edit functionality + ImageEditStream bool // Streaming image edit functionality + ImageVariation bool // Image variation functionality + ImageVariationStream bool // Streaming image variation functionality (if supported) + VideoGeneration bool // Video generation functionality + VideoRetrieve bool // Video retrieve functionality + VideoRemix bool // Video remix functionality (OpenAI only) + VideoDownload bool // Video download functionality + VideoList bool // Video list functionality + VideoDelete bool // Video delete functionality + BatchCreate bool // Batch API create functionality + BatchList bool // Batch API list functionality + BatchRetrieve bool // Batch API retrieve functionality + BatchCancel bool // Batch API cancel functionality + BatchResults bool // Batch API results functionality + FileUpload bool // File API upload functionality + FileList bool // File API list functionality + FileRetrieve bool // File API retrieve functionality + FileDelete bool // File API delete functionality + FileContent bool // File API content download functionality + FileBatchInput bool // Whether batch create supports file-based input (InputFileID) + CountTokens bool // Count tokens functionality + ChatAudio bool // Chat completion with audio input/output functionality + StructuredOutputs bool // Structured outputs (JSON schema) functionality + WebSearchTool bool // Web search tool functionality + ContainerCreate bool // Container API create functionality + ContainerList bool // Container API list functionality + ContainerRetrieve bool // Container API retrieve functionality + ContainerDelete bool // Container API delete functionality + ContainerFileCreate bool // Container File API create functionality + ContainerFileList bool // Container File API list functionality + ContainerFileRetrieve bool // Container File API retrieve functionality + ContainerFileContent bool // Container File API content functionality + ContainerFileDelete bool // Container File API delete functionality + PassThroughExtraParams bool // Pass through extra params functionality + Rerank bool // Rerank functionality + PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + PassthroughStream) + WebSocketResponses bool // WebSocket Responses API mode + Realtime bool // Realtime API (bidirectional audio/text) + Compaction bool // Server-side compaction (context management) + InterleavedThinking bool // Interleaved thinking between tool calls (beta) + FastMode bool // Fast mode for Opus 4.6 (beta: research preview) } // ComprehensiveTestConfig extends TestConfig with additional scenarios @@ -180,7 +180,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -189,7 +189,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), // Use GROQ API key for OpenAI-compatible endpoint - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -198,7 +198,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.ANTHROPIC_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -206,38 +206,38 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, case schemas.Bedrock: return []schemas.Key{ { - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_ARN"), - Deployments: map[string]string{ - "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, }, { - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: map[string]string{ + "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", + "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", + "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), Region: schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")), ARN: schemas.NewEnvVar("env.AWS_BEDROCK_ARN"), - Deployments: map[string]string{ - "claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", - "claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0", - "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", - "claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", - "claude-4.5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, @@ -256,7 +256,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.COHERE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -265,20 +265,20 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o", + "gpt-4o-backup": "gpt-4o-3", + "claude-opus-4-5": "claude-opus-4-5", + "o1": "o1", + "gpt-image-1": "gpt-image-1", + "text-embedding-ada-002": "text-embedding-ada-002", + "sora-2": "sora-2", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), - APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o", - "gpt-4o-backup": "gpt-4o-3", - "claude-opus-4-5": "claude-opus-4-5", - "o1": "o1", - "gpt-image-1": "gpt-image-1", - "text-embedding-ada-002": "text-embedding-ada-002", - "sora-2": "sora-2", - }, + Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), + APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), ClientID: schemas.NewEnvVar("env.AZURE_CLIENT_ID"), ClientSecret: schemas.NewEnvVar("env.AZURE_CLIENT_SECRET"), TenantID: schemas.NewEnvVar("env.AZURE_TENANT_ID"), @@ -287,16 +287,17 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, }, { Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "whisper": "whisper", + "whisper-1": "whisper", + "gpt-4o-mini-tts": "gpt-4o-mini-tts", + "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("env.AZURE_ENDPOINT"), APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"), - Deployments: map[string]string{ - "whisper": "whisper", - "gpt-4o-mini-tts": "gpt-4o-mini-tts", - "gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview", - }, }, }, }, nil @@ -330,15 +331,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), Models: []string{"claude-sonnet-4-5", "claude-4.5-haiku", "claude-opus-4-5"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "claude-sonnet-4-5": "claude-sonnet-4-5", + "claude-4.5-haiku": "claude-haiku-4-5@20251001", + "claude-opus-4-5": "claude-opus-4-5", + }, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), Region: *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION_ANTHROPIC", "us-east5")), AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), - Deployments: map[string]string{ - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-4.5-haiku": "claude-haiku-4-5@20251001", - "claude-opus-4-5": "claude-opus-4-5", - }, }, UseForBatchAPI: bifrost.Ptr(true), }, @@ -347,7 +348,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.MISTRAL_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -356,7 +357,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.GROQ_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -365,7 +366,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.PARASAIL_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -374,7 +375,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.ELEVENLABS_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -383,7 +384,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.PERPLEXITY_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -392,7 +393,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.CEREBRAS_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -401,7 +402,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.GEMINI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -410,7 +411,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENROUTER_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -419,7 +420,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.HUGGING_FACE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -428,7 +429,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.NEBIUS_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -437,7 +438,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.XAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -446,7 +447,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.REPLICATE_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -455,7 +456,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, return []schemas.Key{ { Value: *schemas.NewEnvVar("env.RUNWAY_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, UseForBatchAPI: bifrost.Ptr(true), }, @@ -826,53 +827,54 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageVariationModel: "dall-e-2", ChatAudioModel: "gpt-4o-mini-audio-preview", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - TextCompletionStream: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + TextCompletionStream: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, // OpenAI supports TTS - SpeechSynthesisStream: true, // OpenAI supports streaming TTS - Transcription: true, // OpenAI supports STT with Whisper - TranscriptionStream: true, // OpenAI supports streaming STT - ImageGeneration: true, // OpenAI supports image generation with DALL-E - ImageGenerationStream: true, // OpenAI supports streaming image generation - ImageEdit: true, // OpenAI supports image editing - ImageEditStream: true, // OpenAI supports streaming image editing - ImageVariation: true, // OpenAI supports image variation - ImageVariationStream: false, // OpenAI does not support streaming image variation - Embedding: true, - Reasoning: true, // OpenAI supports reasoning via o1 models - ListModels: true, - BatchCreate: true, // OpenAI supports batch API - BatchList: true, // OpenAI supports batch API - BatchRetrieve: true, // OpenAI supports batch API - BatchCancel: true, // OpenAI supports batch API - BatchResults: true, // OpenAI supports batch API - FileUpload: true, // OpenAI supports file API - FileList: true, // OpenAI supports file API - FileRetrieve: true, // OpenAI supports file API - FileDelete: true, // OpenAI supports file API - FileContent: true, // OpenAI supports file API - ChatAudio: true, // OpenAI supports chat audio - ContainerCreate: true, // OpenAI supports container API - ContainerList: true, // OpenAI supports container API - ContainerRetrieve: true, // OpenAI supports container API - ContainerDelete: true, // OpenAI supports container API - ContainerFileCreate: true, // OpenAI supports container file API - ContainerFileList: true, // OpenAI supports container file API - ContainerFileRetrieve: true, // OpenAI supports container file API - ContainerFileContent: true, // OpenAI supports container file API - ContainerFileDelete: true, // OpenAI supports container file API + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: true, // OpenAI supports TTS + SpeechSynthesisStream: true, // OpenAI supports streaming TTS + Transcription: true, // OpenAI supports STT with Whisper + TranscriptionStream: true, // OpenAI supports streaming STT + ImageGeneration: true, // OpenAI supports image generation with DALL-E + ImageGenerationStream: true, // OpenAI supports streaming image generation + ImageEdit: true, // OpenAI supports image editing + ImageEditStream: true, // OpenAI supports streaming image editing + ImageVariation: true, // OpenAI supports image variation + ImageVariationStream: false, // OpenAI does not support streaming image variation + Embedding: true, + Reasoning: true, // OpenAI supports reasoning via o1 models + ListModels: true, + BatchCreate: true, // OpenAI supports batch API + BatchList: true, // OpenAI supports batch API + BatchRetrieve: true, // OpenAI supports batch API + BatchCancel: true, // OpenAI supports batch API + BatchResults: true, // OpenAI supports batch API + FileUpload: true, // OpenAI supports file API + FileList: true, // OpenAI supports file API + FileRetrieve: true, // OpenAI supports file API + FileDelete: true, // OpenAI supports file API + FileContent: true, // OpenAI supports file API + ChatAudio: true, // OpenAI supports chat audio + ContainerCreate: true, // OpenAI supports container API + ContainerList: true, // OpenAI supports container API + ContainerRetrieve: true, // OpenAI supports container API + ContainerDelete: true, // OpenAI supports container API + ContainerFileCreate: true, // OpenAI supports container file API + ContainerFileList: true, // OpenAI supports container file API + ContainerFileRetrieve: true, // OpenAI supports container file API + ContainerFileContent: true, // OpenAI supports container file API + ContainerFileDelete: true, // OpenAI supports container file API }, Fallbacks: []schemas.Fallback{ {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, @@ -883,37 +885,38 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "claude-3-7-sonnet-20250219", TextModel: "", // Anthropic doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - PromptCaching: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Anthropic does not support image editing - ImageEditStream: false, // Anthropic does not support streaming image editing - ImageVariation: false, // Anthropic does not support image variation - ImageVariationStream: false, // Anthropic does not support streaming image variation - ListModels: true, - BatchCreate: true, // Anthropic supports batch API - BatchList: true, // Anthropic supports batch API - BatchRetrieve: true, // Anthropic supports batch API - BatchCancel: true, // Anthropic supports batch API - BatchResults: true, // Anthropic supports batch API + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + PromptCaching: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Anthropic does not support image editing + ImageEditStream: false, // Anthropic does not support streaming image editing + ImageVariation: false, // Anthropic does not support image variation + ImageVariationStream: false, // Anthropic does not support streaming image variation + ListModels: true, + BatchCreate: true, // Anthropic supports batch API + BatchList: true, // Anthropic supports batch API + BatchRetrieve: true, // Anthropic supports batch API + BatchCancel: true, // Anthropic supports batch API + BatchResults: true, // Anthropic supports batch API }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -926,42 +929,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageEditModel: "amazon.titan-image-generator-v1", ImageVariationModel: "amazon.titan-image-generator-v1", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported for Claude - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported for Claude + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - PromptCaching: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: true, // Bedrock supports image editing - ImageEditStream: false, // Bedrock does not support streaming image editing - ImageVariation: true, // Bedrock supports image variation - ImageVariationStream: false, // Bedrock does not support streaming image variation - ListModels: true, - BatchCreate: true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config) - BatchList: true, // Bedrock supports listing batch jobs - BatchRetrieve: true, // Bedrock supports retrieving batch jobs - BatchCancel: true, // Bedrock supports stopping batch jobs - BatchResults: true, // Bedrock batch results via S3 - FileUpload: true, // Bedrock file upload to S3 (requires S3 config) - FileList: true, // Bedrock file list from S3 (requires S3 config) - FileRetrieve: true, // Bedrock file retrieve from S3 (requires S3 config) - FileDelete: true, // Bedrock file delete from S3 (requires S3 config) - FileContent: true, // Bedrock file content from S3 (requires S3 config) + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + PromptCaching: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: true, // Bedrock supports image editing + ImageEditStream: false, // Bedrock does not support streaming image editing + ImageVariation: true, // Bedrock supports image variation + ImageVariationStream: false, // Bedrock does not support streaming image variation + ListModels: true, + BatchCreate: true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config) + BatchList: true, // Bedrock supports listing batch jobs + BatchRetrieve: true, // Bedrock supports retrieving batch jobs + BatchCancel: true, // Bedrock supports stopping batch jobs + BatchResults: true, // Bedrock batch results via S3 + FileUpload: true, // Bedrock file upload to S3 (requires S3 config) + FileList: true, // Bedrock file list from S3 (requires S3 config) + FileRetrieve: true, // Bedrock file retrieve from S3 (requires S3 config) + FileDelete: true, // Bedrock file delete from S3 (requires S3 config) + FileContent: true, // Bedrock file content from S3 (requires S3 config) }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -972,31 +976,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "command-a-03-2025", TextModel: "", // Cohere focuses on chat Scenarios: TestScenarios{ - TextCompletion: false, // Not typical for Cohere - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: false, // May not support automatic - ImageURL: false, // Check if supported - ImageBase64: false, // Check if supported - MultipleImages: false, // Check if supported - CompleteEnd2End: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Cohere does not support image editing - ImageEditStream: false, // Cohere does not support streaming image editing - ImageVariation: false, // Cohere does not support image variation - ImageVariationStream: false, // Cohere does not support streaming image variation - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: false, // May not support automatic + ImageURL: false, // Check if supported + ImageBase64: false, // Check if supported + MultipleImages: false, // Check if supported + CompleteEnd2End: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Cohere does not support image editing + ImageEditStream: false, // Cohere does not support streaming image editing + ImageVariation: false, // Cohere does not support image variation + ImageVariationStream: false, // Cohere does not support streaming image variation + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1012,42 +1017,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "gpt-image-1", ImageEditModel: "dall-e-2", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, // Supported via gpt-4o-mini-tts - SpeechSynthesisStream: true, // Supported via gpt-4o-mini-tts - Transcription: true, // Supported via whisper-1 - TranscriptionStream: false, // Not properly supported yet by Azure - Embedding: true, - ImageGeneration: false, // Skipped for Azure - ImageGenerationStream: false, // Skipped for Azure - ImageEdit: true, // Azure supports image editing - ImageEditStream: true, // Azure supports streaming image editing - ImageVariation: false, // Azure does not support image variation - ImageVariationStream: false, // Azure does not support streaming image variation - ListModels: true, - BatchCreate: true, // Azure supports batch API - BatchList: true, // Azure supports batch API - BatchRetrieve: true, // Azure supports batch API - BatchCancel: true, // Azure supports batch API - BatchResults: true, // Azure supports batch API - FileUpload: true, // Azure supports file API - FileList: true, // Azure supports file API - FileRetrieve: true, // Azure supports file API - FileDelete: true, // Azure supports file API - FileContent: true, // Azure supports file API - ChatAudio: true, // Azure supports chat audio + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: true, // Supported via gpt-4o-mini-tts + SpeechSynthesisStream: true, // Supported via gpt-4o-mini-tts + Transcription: true, // Supported via whisper-1 + TranscriptionStream: false, // Not properly supported yet by Azure + Embedding: true, + ImageGeneration: false, // Skipped for Azure + ImageGenerationStream: false, // Skipped for Azure + ImageEdit: true, // Azure supports image editing + ImageEditStream: true, // Azure supports streaming image editing + ImageVariation: false, // Azure does not support image variation + ImageVariationStream: false, // Azure does not support streaming image variation + ListModels: true, + BatchCreate: true, // Azure supports batch API + BatchList: true, // Azure supports batch API + BatchRetrieve: true, // Azure supports batch API + BatchCancel: true, // Azure supports batch API + BatchResults: true, // Azure supports batch API + FileUpload: true, // Azure supports file API + FileList: true, // Azure supports file API + FileRetrieve: true, // Azure supports file API + FileDelete: true, // Azure supports file API + FileContent: true, // Azure supports file API + ChatAudio: true, // Azure supports chat audio }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1061,31 +1067,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "imagen-4.0-generate-001", ImageEditModel: "imagen-4.0-generate-001", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: true, // Vertex supports image editing - ImageEditStream: false, // Vertex does not support streaming image editing - ImageVariation: false, // Vertex does not support image variation - ImageVariationStream: false, // Vertex does not support streaming image variation - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, // Vertex supports image editing + ImageEditStream: false, // Vertex does not support streaming image editing + ImageVariation: false, // Vertex does not support image variation + ImageVariationStream: false, // Vertex does not support streaming image variation + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: true, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1097,30 +1104,31 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "", // Mistral focuses on chat TranscriptionModel: "voxtral-mini-latest", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: true, // Supported via voxtral-mini-latest - TranscriptionStream: true, // Supported via voxtral-mini-latest - Embedding: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Mistral does not support image editing - ImageEditStream: false, // Mistral does not support streaming image editing - ImageVariation: false, // Mistral does not support image variation - ImageVariationStream: false, // Mistral does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: true, // Supported via voxtral-mini-latest + TranscriptionStream: true, // Supported via voxtral-mini-latest + Embedding: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Mistral does not support image editing + ImageEditStream: false, // Mistral does not support streaming image editing + ImageVariation: false, // Mistral does not support image variation + ImageVariationStream: false, // Mistral does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1131,31 +1139,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama3.2", TextModel: "", // Ollama focuses on chat Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Ollama does not support image editing - ImageEditStream: false, // Ollama does not support streaming image editing - ImageVariation: false, // Ollama does not support image variation - ImageVariationStream: false, // Ollama does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Ollama does not support image editing + ImageEditStream: false, // Ollama does not support streaming image editing + ImageVariation: false, // Ollama does not support image variation + ImageVariationStream: false, // Ollama does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1166,31 +1175,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama-3.3-70b-versatile", TextModel: "", // Groq doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // Groq does not support image editing - ImageEditStream: false, // Groq does not support streaming image editing - ImageVariation: false, // Groq does not support image variation - ImageVariationStream: false, // Groq does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // Groq does not support image editing + ImageEditStream: false, // Groq does not support streaming image editing + ImageVariation: false, // Groq does not support image variation + ImageVariationStream: false, // Groq does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1231,31 +1241,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "llama-3.3-70b-versatile", TextModel: "", // Custom OpenAI instance doesn't support text completion Scenarios: TestScenarios{ - TextCompletion: false, - SimpleChat: true, // Enable simple chat for testing - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, + SimpleChat: true, // Enable simple chat for testing + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, - ImageGeneration: false, // ProviderOpenAICustom does not support image generation - ImageGenerationStream: false, // ProviderOpenAICustom does not support streaming image generation - ImageEdit: false, // ProviderOpenAICustom does not support image editing - ImageEditStream: false, // ProviderOpenAICustom does not support streaming image editing - ImageVariation: false, // ProviderOpenAICustom does not support image variation - ImageVariationStream: false, // ProviderOpenAICustom does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, + ImageGeneration: false, // ProviderOpenAICustom does not support image generation + ImageGenerationStream: false, // ProviderOpenAICustom does not support streaming image generation + ImageEdit: false, // ProviderOpenAICustom does not support image editing + ImageEditStream: false, // ProviderOpenAICustom does not support streaming image editing + ImageVariation: false, // ProviderOpenAICustom does not support image variation + ImageVariationStream: false, // ProviderOpenAICustom does not support streaming image variation + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1271,41 +1282,42 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ImageGenerationModel: "imagen-4.0-generate-001", ImageEditModel: "imagen-4.0-generate-001", Scenarios: TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: true, - SpeechSynthesisStream: true, - Transcription: true, - TranscriptionStream: true, - Embedding: true, - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: true, // Gemini supports image editing - ImageEditStream: false, // Gemini does not support streaming image editing - ImageVariation: false, // Gemini does not support image variation - ImageVariationStream: false, // Gemini does not support streaming image variation - ListModels: true, - BatchCreate: true, - BatchList: true, - BatchRetrieve: true, - BatchCancel: true, - BatchResults: true, - FileUpload: true, - FileList: true, - FileRetrieve: true, - FileDelete: true, - FileContent: false, // Gemini doesn't support direct content download + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, // Gemini supports image editing + ImageEditStream: false, // Gemini does not support streaming image editing + ImageVariation: false, // Gemini does not support image variation + ImageVariationStream: false, // Gemini does not support streaming image variation + ListModels: true, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: false, // Gemini doesn't support direct content download }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1316,31 +1328,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "openai/gpt-4o", TextModel: "google/gemini-2.5-flash", Scenarios: TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // OpenRouter does not support image editing - ImageEditStream: false, // OpenRouter does not support streaming image editing - ImageVariation: false, // OpenRouter does not support image variation - ImageVariationStream: false, // OpenRouter does not support streaming image variation - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: false, - TranscriptionStream: false, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // OpenRouter does not support image editing + ImageEditStream: false, // OpenRouter does not support streaming image editing + ImageVariation: false, // OpenRouter does not support image variation + ImageVariationStream: false, // OpenRouter does not support streaming image variation + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: false, + TranscriptionStream: false, + Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1394,31 +1407,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "", // XAI focuses on chat ImageGenerationModel: "grok-2-image", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: false, // XAI does not support image editing - ImageEditStream: false, // XAI does not support streaming image editing - ImageVariation: false, // XAI does not support image variation - ImageVariationStream: false, // XAI does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: false, // XAI does not support image editing + ImageEditStream: false, // XAI does not support streaming image editing + ImageVariation: false, // XAI does not support image variation + ImageVariationStream: false, // XAI does not support streaming image variation + ListModels: true, }, }, { @@ -1427,27 +1441,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "openai/gpt-4.1-mini", ImageGenerationModel: "black-forest-labs/flux-dev", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ListModels: true, - ImageGeneration: true, - ImageGenerationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + ListModels: true, + ImageGeneration: true, + ImageGenerationStream: false, }, }, { Provider: schemas.VLLM, @@ -1456,27 +1471,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ EmbeddingModel: "Qwen/Qwen3-Embedding-0.6B", TranscriptionModel: "openai/whisper-small", Scenarios: TestScenarios{ - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: true, // VLLM supports transcription - TranscriptionStream: true, // VLLM supports transcription streaming - Embedding: true, // VLLM supports embedding - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // VLLM does not support image editing - ImageEditStream: false, // VLLM does not support streaming image editing - ImageVariation: false, // VLLM does not support image variation - ImageVariationStream: false, // VLLM does not support streaming image variation - ListModels: true, - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: true, // VLLM supports transcription + TranscriptionStream: true, // VLLM supports transcription streaming + Embedding: true, // VLLM supports embedding + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // VLLM does not support image editing + ImageEditStream: false, // VLLM does not support streaming image editing + ImageVariation: false, // VLLM does not support image variation + ImageVariationStream: false, // VLLM does not support streaming image variation + ListModels: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, + End2EndToolCalling: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, diff --git a/core/internal/llmtests/image_edit.go b/core/internal/llmtests/image_edit.go index 56ad66d502..deed0bd820 100644 --- a/core/internal/llmtests/image_edit.go +++ b/core/internal/llmtests/image_edit.go @@ -364,8 +364,8 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context t.Error("❌ ExtraFields.Provider is empty") } - if imageEditResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageEditResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageEditRequest @@ -374,7 +374,7 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Logf("βœ… Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.ModelRequested, len(imageEditResponse.Data)) + imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data)) }) } diff --git a/core/internal/llmtests/image_generation.go b/core/internal/llmtests/image_generation.go index 81a0626978..1516ff0088 100644 --- a/core/internal/llmtests/image_generation.go +++ b/core/internal/llmtests/image_generation.go @@ -145,12 +145,12 @@ func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Error("❌ ExtraFields.Provider is empty") } - if imageGenerationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } t.Logf("βœ… Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.ModelRequested, len(imageGenerationResponse.Data)) + imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data)) }) } diff --git a/core/internal/llmtests/image_variation.go b/core/internal/llmtests/image_variation.go index 0aca33a63f..d0c4d18e78 100644 --- a/core/internal/llmtests/image_variation.go +++ b/core/internal/llmtests/image_variation.go @@ -162,8 +162,8 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co t.Error("❌ ExtraFields.Provider is empty") } - if imageVariationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageVariationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageVariationRequest @@ -172,7 +172,7 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Logf("βœ… Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.ModelRequested, len(imageVariationResponse.Data)) + imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.OriginalModelRequested, len(imageVariationResponse.Data)) }) } diff --git a/core/internal/llmtests/realtime.go b/core/internal/llmtests/realtime.go index 821aeba9eb..400f5f9cda 100644 --- a/core/internal/llmtests/realtime.go +++ b/core/internal/llmtests/realtime.go @@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go index b63fe8ad7e..baa5260cca 100644 --- a/core/internal/llmtests/response_validation.go +++ b/core/internal/llmtests/response_validation.go @@ -847,7 +847,7 @@ func validateResponsesBasicStructure(response *schemas.BifrostResponsesResponse, } provider := response.ExtraFields.Provider - model := response.ExtraFields.ModelDeployment + model := response.ExtraFields.ResolvedModelUsed // Verify top level status is present for OpenAI and Azure with non-Claude models if provider != "" && (provider == schemas.OpenAI || provider == schemas.Azure) && !strings.Contains(strings.ToLower(model), "claude") { @@ -976,8 +976,7 @@ func validateResponsesTechnicalFields(t *testing.T, response *schemas.BifrostRes // Check model field if expectations.ShouldHaveModel { - if strings.TrimSpace(response.Model) == "" && - strings.TrimSpace(response.ExtraFields.ModelDeployment) == "" { + if strings.TrimSpace(response.Model) == "" { result.Passed = false result.Errors = append(result.Errors, fmt.Sprintf("Expected model field but not present or empty (provider: %s)", response.ExtraFields.Provider)) } diff --git a/core/internal/llmtests/speech_synthesis.go b/core/internal/llmtests/speech_synthesis.go index 4e08d6e2c8..aae66423a3 100644 --- a/core/internal/llmtests/speech_synthesis.go +++ b/core/internal/llmtests/speech_synthesis.go @@ -239,8 +239,8 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize) } - if speechResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.ModelRequested) + if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested) } t.Logf("βœ… HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio)) @@ -344,8 +344,8 @@ func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpee t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes) } - if expectedModel != "" && response.ExtraFields.ModelRequested != expectedModel { - t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.ModelRequested) + if expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel { + t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.OriginalModelRequested) } t.Logf("βœ… Audio validation passed: %d bytes generated", audioSize) diff --git a/core/internal/llmtests/speech_synthesis_stream.go b/core/internal/llmtests/speech_synthesis_stream.go index 87268f3c17..8b7bdc8efb 100644 --- a/core/internal/llmtests/speech_synthesis_stream.go +++ b/core/internal/llmtests/speech_synthesis_stream.go @@ -184,8 +184,8 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type) } - if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } @@ -348,8 +348,8 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, t.Logf("βœ… HD chunk %d: %d bytes", chunkCount, chunkSize) } - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } diff --git a/core/internal/llmtests/transcription_stream.go b/core/internal/llmtests/transcription_stream.go index dfc80fc533..a28239c00f 100644 --- a/core/internal/llmtests/transcription_stream.go +++ b/core/internal/llmtests/transcription_stream.go @@ -242,8 +242,12 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type) } - if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested) + gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested + if gotModel == "" { + t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested") + } + if gotModel != testConfig.TranscriptionModel { + t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel) } lastResponse = DeepCopyBifrostStreamChunk(response) diff --git a/core/internal/llmtests/video.go b/core/internal/llmtests/video.go index c622edf6b4..8ac2d6e396 100644 --- a/core/internal/llmtests/video.go +++ b/core/internal/llmtests/video.go @@ -48,8 +48,8 @@ func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C if resp.ExtraFields.Provider == "" { t.Fatal("❌ Video generation extra_fields.provider is empty") } - if resp.ExtraFields.ModelRequested == "" { - t.Fatal("❌ Video generation extra_fields.model_requested is empty") + if resp.ExtraFields.OriginalModelRequested == "" { + t.Fatal("❌ Video generation extra_fields.original_model_requested is empty") } t.Logf("βœ… Video generation created job: id=%s status=%s", resp.ID, resp.Status) diff --git a/core/internal/llmtests/websocket_responses.go b/core/internal/llmtests/websocket_responses.go index 420a049fb7..966463dade 100644 --- a/core/internal/llmtests/websocket_responses.go +++ b/core/internal/llmtests/websocket_responses.go @@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/mcptests/agent_test_helpers.go b/core/internal/mcptests/agent_test_helpers.go index d19d953ca0..85512dcce6 100644 --- a/core/internal/mcptests/agent_test_helpers.go +++ b/core/internal/mcptests/agent_test_helpers.go @@ -131,11 +131,11 @@ func SetupAgentTest(t *testing.T, config AgentTestConfig) (*mcp.MCPManager, *Dyn // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -192,11 +192,11 @@ func SetupAgentTestWithClients(t *testing.T, config AgentTestConfig, customClien // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) diff --git a/core/internal/mcptests/codemode_stdio_test.go b/core/internal/mcptests/codemode_stdio_test.go index 8fe5841a82..aab3a15172 100644 --- a/core/internal/mcptests/codemode_stdio_test.go +++ b/core/internal/mcptests/codemode_stdio_test.go @@ -56,27 +56,27 @@ func setupCodeModeWithSTDIOServers(t *testing.T, serverNames ...string) (*mcp.MC config = GetTemperatureMCPClientConfig(bifrostRoot) config.IsCodeModeClient = true config.ID = "temperature-client" // Match test expectations - config.Name = "temperature" // Use lowercase to match test code + config.Name = "temperature" // Use lowercase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "go-test-server": config = GetGoTestServerConfig(bifrostRoot) config.ID = "goTestServer-client" // Match test expectations - config.Name = "goTestServer" // Use camelCase to match test code + config.Name = "goTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "edge-case-server": config = GetEdgeCaseServerConfig(bifrostRoot) config.ID = "edgeCaseServer-client" // Match test expectations - config.Name = "edgeCaseServer" // Use camelCase to match test code + config.Name = "edgeCaseServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "error-test-server": config = GetErrorTestServerConfig(bifrostRoot) config.ID = "errorTestServer-client" // Match test expectations - config.Name = "errorTestServer" // Use camelCase to match test code + config.Name = "errorTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "parallel-test-server": config = GetParallelTestServerConfig(bifrostRoot) config.ID = "parallelTestServer-client" // Match test expectations - config.Name = "parallelTestServer" // Use camelCase to match test code + config.Name = "parallelTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "test-tools-server": // test-tools-server doesn't have a fixture, set up manually @@ -367,9 +367,9 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { expectedError string }{ { - name: "allow_only_test_tools_server", - includeClients: []string{"testToolsServer"}, - code: `result = testToolsServer.echo(message="allowed")`, + name: "allow_only_test_tools_server", + includeClients: []string{"testToolsServer"}, + code: `result = testToolsServer.echo(message="allowed")`, shouldSucceed: true, expectedInResult: "allowed", }, @@ -377,13 +377,13 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_test_tools_server", includeClients: []string{"temperature"}, code: `result = testToolsServer.echo(message="blocked")`, - shouldSucceed: false, - expectedError: "undefined: testToolsServer", + shouldSucceed: false, + expectedError: "undefined: testToolsServer", }, { - name: "allow_only_temperature_server", - includeClients: []string{"temperature"}, - code: `result = temperature.get_temperature(location="Paris")`, + name: "allow_only_temperature_server", + includeClients: []string{"temperature"}, + code: `result = temperature.get_temperature(location="Paris")`, shouldSucceed: true, expectedInResult: "Paris", }, @@ -391,8 +391,8 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_temperature_server", includeClients: []string{"testToolsServer"}, code: `result = temperature.get_temperature(location="blocked")`, - shouldSucceed: false, - expectedError: "undefined: temperature", + shouldSucceed: false, + expectedError: "undefined: temperature", }, { name: "allow_both_servers", @@ -409,7 +409,7 @@ result = {"echo": echo, "temp": temp}`, t.Run(tc.name, func(t *testing.T) { // Create context with client filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied at tool listing level @@ -524,7 +524,7 @@ result = {"echo": echo, "calc": calc}`, t.Run(tc.name, func(t *testing.T) { // Create context with tool filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied @@ -622,10 +622,10 @@ result = {"echo": echo, "temp": temp}`, // Create context with both client and tool filtering baseCtx := context.Background() if tc.includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) } if tc.includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -1692,7 +1692,7 @@ result = {"count": 3}`, for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) toolCall := CreateExecuteToolCodeCall(fmt.Sprintf("call-%s", tc.name), tc.code) diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go index a1c3823831..e3c5793df4 100644 --- a/core/internal/mcptests/concurrency_advanced_test.go +++ b/core/internal/mcptests/concurrency_advanced_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -533,14 +532,14 @@ func TestConcurrent_FilteringChanges(t *testing.T) { if id%2 == 0 { // Even: allow all tools baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } else { // Odd: allow only echo baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go index 88b00a9f70..f760ae5ac0 100644 --- a/core/internal/mcptests/fixtures.go +++ b/core/internal/mcptests/fixtures.go @@ -1422,7 +1422,7 @@ func (a *testAccount) GetKeysForProvider(ctx context.Context, providerKey schema return []schemas.Key{ { Value: *schemas.NewEnvVar(apiKey), - Models: []string{}, // Empty means all models + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, nil @@ -1460,6 +1460,17 @@ func setupBifrost(t *testing.T) *bifrost.Bifrost { return bifrostInstance } +// noopPluginPipeline is a passthrough pipeline used in tests that don't need plugin hooks. +type noopPluginPipeline struct{} + +func (n *noopPluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) { + return req, nil, 0 +} + +func (n *noopPluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { + return mcpResp, bifrostErr +} + // setupMCPManager creates an MCP manager for testing func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mcp.MCPManager { t.Helper() @@ -1472,9 +1483,14 @@ func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mc clientConfigPtrs[i] = &clientConfigs[i] } - // Create MCP config + // Create MCP config with a no-op plugin pipeline so that codemode tool calls + // work correctly even when no Bifrost instance is attached. mcpConfig := &schemas.MCPConfig{ ClientConfigs: clientConfigPtrs, + PluginPipelineProvider: func() interface{} { + return &noopPluginPipeline{} + }, + ReleasePluginPipeline: func(pipeline interface{}) {}, } // Create Starlark CodeMode @@ -1984,10 +2000,10 @@ func AssertExecutionTimeUnder(t *testing.T, fn func(), maxDuration time.Duration func CreateTestContextWithMCPFilter(includeClients []string, includeTools []string) *schemas.BifrostContext { baseCtx := context.Background() if includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, includeClients) } if includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, includeTools) } return schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/internal/mcptests/tool_filtering_test.go b/core/internal/mcptests/tool_filtering_test.go index eb8b370a28..15fde03d75 100644 --- a/core/internal/mcptests/tool_filtering_test.go +++ b/core/internal/mcptests/tool_filtering_test.go @@ -160,7 +160,7 @@ func TestToolsToExecute_ExplicitList(t *testing.T) { // Verify configuration was set correctly clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) } func TestToolsToExecute_SingleTool(t *testing.T) { @@ -178,10 +178,10 @@ func TestToolsToExecute_SingleTool(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) // Verify it's not allow-all - assert.NotEqual(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard") + assert.NotEqual(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard") } // ============================================================================= @@ -204,8 +204,8 @@ func TestToolsToAutoExecute_Basic(t *testing.T) { // Verify the client was created with correct configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute) } func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) { @@ -224,8 +224,8 @@ func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute) assert.NotEqual(t, clients[0].ExecutionConfig.ToolsToExecute, clients[0].ExecutionConfig.ToolsToAutoExecute) } @@ -245,7 +245,7 @@ func TestToolsToAutoExecute_Wildcard(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute) } // ============================================================================= @@ -267,7 +267,7 @@ func TestContextFilteringRestrictsWildcard(t *testing.T) { // Verify client configuration allows all clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute) // Context restricts to only specific tools (verify context works separately) ctx := CreateTestContextWithMCPFilter(nil, []string{"encode"}) @@ -305,9 +305,9 @@ func TestFilteringMultipleClients_DifferentRules(t *testing.T) { // Find and verify each client for _, client := range clients { if client.ExecutionConfig.ID == "stdio-client-1" { - assert.Equal(t, []string{"encode"}, client.ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, client.ExecutionConfig.ToolsToExecute) } else if client.ExecutionConfig.ID == "stdio-client-2" { - assert.Equal(t, []string{"*"}, client.ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"*"}, client.ExecutionConfig.ToolsToExecute) } } } @@ -331,7 +331,7 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Verify initial configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) // Edit client to only allow second tool clientConfig.ToolsToExecute = []string{"hash"} @@ -341,6 +341,6 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Verify configuration changed clients = manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.NotEqual(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.NotEqual(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) } diff --git a/core/mcp/agent.go b/core/mcp/agent.go index fe4481ad7a..96d16ec24e 100644 --- a/core/mcp/agent.go +++ b/core/mcp/agent.go @@ -1,6 +1,7 @@ package mcp import ( + "errors" "fmt" "strings" "sync" @@ -10,7 +11,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) - type AgentModeExecutor struct { logger schemas.Logger } @@ -40,7 +40,7 @@ func (a *AgentModeExecutor) ExecuteAgentForChatRequest( makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Create adapter for Chat API adapter := &chatAPIAdapter{ @@ -143,7 +143,7 @@ func (a *AgentModeExecutor) executeAgent( adapter agentAPIAdapter, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (interface{}, *schemas.BifrostError) { // Get initial response from adapter currentResponse := adapter.getInitialResponse() @@ -157,6 +157,9 @@ func (a *AgentModeExecutor) executeAgent( allExecutedToolResults := make([]*schemas.ChatMessage, 0) allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + // Accumulate token usage across all LLM calls in the agent loop + accumulatedUsage := adapter.extractUsage(currentResponse) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) if ok { ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) @@ -207,14 +210,8 @@ func (a *AgentModeExecutor) executeAgent( continue } - // Step 1: Convert literal \n escape sequences to actual newlines for parsing - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - if len(codeWithNewlines) != len(code) { - a.logger.Debug("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix) - } - - // Step 2: Extract tool calls from code during AST formation - extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) + // Step 1: Extract tool calls from the original source code during validation + extractedToolCalls, err := extractToolCallsFromCode(code) if err != nil { a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) @@ -289,6 +286,8 @@ func (a *AgentModeExecutor) executeAgent( wg := sync.WaitGroup{} wg.Add(len(autoExecutableTools)) channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) + var authRequiredErr *schemas.MCPUserOAuthRequiredError + var authRequiredOnce sync.Once for _, toolCall := range autoExecutableTools { go func(toolCall schemas.ChatAssistantMessageToolCall) { defer wg.Done() @@ -305,6 +304,15 @@ func (a *AgentModeExecutor) executeAgent( mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest) if toolErr != nil { + // Check if this is a per-user OAuth auth-required error + var oauthErr *schemas.MCPUserOAuthRequiredError + if errors.As(toolErr, &oauthErr) { + authRequiredOnce.Do(func() { + authRequiredErr = oauthErr + }) + channelToolResults <- createToolResultMessage(toolCall, "", toolErr) + return + } a.logger.Warn("Tool execution failed: %v", toolErr) channelToolResults <- createToolResultMessage(toolCall, "", toolErr) } else if mcpResponse != nil && mcpResponse.ChatMessage != nil { @@ -321,6 +329,23 @@ func (a *AgentModeExecutor) executeAgent( wg.Wait() close(channelToolResults) + // If any tool required per-user OAuth, stop the agent loop and return the error + if authRequiredErr != nil { + statusCode := 401 + errType := "mcp_auth_required" + return nil, &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: authRequiredErr.Message, + Type: &errType, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + MCPAuthRequired: authRequiredErr, + }, + } + } + // Collect tool results executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) for toolResult := range channelToolResults { @@ -342,6 +367,8 @@ func (a *AgentModeExecutor) executeAgent( if depth == 1 && len(allExecutedToolResults) == 0 { return currentResponse, nil } + // Apply accumulated usage before building the final response + adapter.applyUsage(currentResponse, accumulatedUsage) // Create response with all executed tool results from all iterations, and non-auto-executable tool calls return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil } @@ -364,11 +391,127 @@ func (a *AgentModeExecutor) executeAgent( } currentResponse = response + accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse)) } + adapter.applyUsage(currentResponse, accumulatedUsage) return currentResponse, nil } +// mergeUsage sums token counts and costs from two BifrostLLMUsage values. +// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is. +func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage { + if add == nil { + return base + } + if base == nil { + return add + } + + merged := &schemas.BifrostLLMUsage{ + PromptTokens: base.PromptTokens + add.PromptTokens, + CompletionTokens: base.CompletionTokens + add.CompletionTokens, + TotalTokens: base.TotalTokens + add.TotalTokens, + } + + // Merge prompt token details + if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil { + bd := base.PromptTokensDetails + ad := add.PromptTokensDetails + if bd == nil { + bd = &schemas.ChatPromptTokensDetails{} + } + if ad == nil { + ad = &schemas.ChatPromptTokensDetails{} + } + merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: bd.TextTokens + ad.TextTokens, + AudioTokens: bd.AudioTokens + ad.AudioTokens, + ImageTokens: bd.ImageTokens + ad.ImageTokens, + CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens, + CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens, + } + } + + // Merge completion token details + if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil { + bd := base.CompletionTokensDetails + ad := add.CompletionTokensDetails + if bd == nil { + bd = &schemas.ChatCompletionTokensDetails{} + } + if ad == nil { + ad = &schemas.ChatCompletionTokensDetails{} + } + merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: bd.TextTokens + ad.TextTokens, + AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens, + AudioTokens: bd.AudioTokens + ad.AudioTokens, + ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens, + RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens, + } + if bd.CitationTokens != nil || ad.CitationTokens != nil { + bct := 0 + act := 0 + if bd.CitationTokens != nil { + bct = *bd.CitationTokens + } + if ad.CitationTokens != nil { + act = *ad.CitationTokens + } + sum := bct + act + merged.CompletionTokensDetails.CitationTokens = &sum + } + if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil { + bnsq := 0 + ansq := 0 + if bd.NumSearchQueries != nil { + bnsq = *bd.NumSearchQueries + } + if ad.NumSearchQueries != nil { + ansq = *ad.NumSearchQueries + } + sum := bnsq + ansq + merged.CompletionTokensDetails.NumSearchQueries = &sum + } + if bd.ImageTokens != nil || ad.ImageTokens != nil { + bit := 0 + ait := 0 + if bd.ImageTokens != nil { + bit = *bd.ImageTokens + } + if ad.ImageTokens != nil { + ait = *ad.ImageTokens + } + sum := bit + ait + merged.CompletionTokensDetails.ImageTokens = &sum + } + } + + // Merge cost + if base.Cost != nil || add.Cost != nil { + bc := base.Cost + ac := add.Cost + if bc == nil { + bc = &schemas.BifrostCost{} + } + if ac == nil { + ac = &schemas.BifrostCost{} + } + merged.Cost = &schemas.BifrostCost{ + InputTokensCost: bc.InputTokensCost + ac.InputTokensCost, + OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost, + ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost, + CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost, + SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost, + RequestCost: bc.RequestCost + ac.RequestCost, + TotalCost: bc.TotalCost + ac.TotalCost, + } + } + + return merged +} + // extractToolCalls extracts all tool calls from a chat response. // It iterates through all choices in the response and collects tool calls // from assistant messages. @@ -460,25 +603,23 @@ func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager C // Get auto-executable tools from config toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute - if len(toolsToAutoExecute) == 0 { + if toolsToAutoExecute.IsEmpty() { // No auto-executable tools configured for this client continue } // Parse tool names (as they appear in JavaScript code) autoExecutableTools := []string{} - for _, originalToolName := range toolsToAutoExecute { - // Handle wildcard "*" - means all tools are auto-executable - if originalToolName == "*" { - autoExecutableTools = append(autoExecutableTools, "*") - continue + if toolsToAutoExecute.IsUnrestricted() { + autoExecutableTools = append(autoExecutableTools, "*") + } else { + for _, originalToolName := range toolsToAutoExecute { + // Replace - with _ for code mode compatibility, then parse for JS compatibility + toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") + parsedToolName := parseToolName(toolNameForCode) + autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Replace - with _ for code mode compatibility, then parse for JS compatibility - toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") - parsedToolName := parseToolName(toolNameForCode) - autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Add to map if there are auto-executable tools if len(autoExecutableTools) > 0 { allowedTools[clientName] = autoExecutableTools diff --git a/core/mcp/agentadaptors.go b/core/mcp/agentadaptors.go index 6986cd9798..7b78df4389 100644 --- a/core/mcp/agentadaptors.go +++ b/core/mcp/agentadaptors.go @@ -59,6 +59,12 @@ type agentAPIAdapter interface { executedToolCalls []schemas.ChatAssistantMessageToolCall, nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, ) interface{} + + // extractUsage returns the token usage from a response as BifrostLLMUsage. + extractUsage(response interface{}) *schemas.BifrostLLMUsage + + // applyUsage sets accumulated usage on the response in place. + applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) } // chatAPIAdapter implements agentAPIAdapter for Chat API @@ -175,6 +181,14 @@ func (c *chatAPIAdapter) createResponseWithExecutedTools( ) } +func (c *chatAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage { + return response.(*schemas.BifrostChatResponse).Usage +} + +func (c *chatAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) { + response.(*schemas.BifrostChatResponse).Usage = usage +} + // createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response // that includes executed tool results and non-auto-executable tool calls. The response // contains a formatted text summary of executed tool results and includes the non-auto-executable @@ -390,6 +404,14 @@ func (r *responsesAPIAdapter) createResponseWithExecutedTools( ) } +func (r *responsesAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage { + return response.(*schemas.BifrostResponsesResponse).Usage.ToBifrostLLMUsage() +} + +func (r *responsesAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) { + response.(*schemas.BifrostResponsesResponse).Usage = usage.ToResponsesResponseUsage() +} + // createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response // that includes executed tool results and non-auto-executable tool calls. The response // contains a formatted text summary of executed tool results and includes the non-auto-executable diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go index 36b12243da..b6bd442c20 100644 --- a/core/mcp/clientmanager.go +++ b/core/mcp/clientmanager.go @@ -118,6 +118,33 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { // This is to avoid deadlocks when the connection attempt is made m.mu.Unlock() + // Per-user OAuth: skip persistent connection. Auth is per-request at runtime. + // The admin verifies the configuration via a sample login before this is called, + // and tools are populated separately via SetClientTools(). + if configCopy.AuthType == schemas.MCPAuthTypePerUserOauth { + m.mu.Lock() + if client, exists := m.clientMap[config.ID]; exists { + if config.ConnectionString != nil { + url := config.ConnectionString.GetValue() + client.ConnectionInfo.ConnectionURL = &url + } + // Restore discovered tools from config (persisted in DB across restarts) + if len(config.DiscoveredTools) > 0 { + for toolName, tool := range config.DiscoveredTools { + client.ToolMap[toolName] = tool + } + client.ToolNameMapping = config.DiscoveredToolNameMapping + client.State = schemas.MCPConnectionStateConnected + m.logger.Info("%s Per-user OAuth MCP client '%s' restored with %d tools", MCPLogPrefix, config.Name, len(config.DiscoveredTools)) + } else { + client.State = schemas.MCPConnectionStatePendingTools + m.logger.Info("%s Per-user OAuth MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.Name) + } + } + m.mu.Unlock() + return nil + } + // Connect using the copied config if err := m.connectToMCPClient(configCopy); err != nil { // Clean up the failed entry β€” this is a user-initiated action (UI/API), @@ -131,6 +158,92 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { return nil } +// VerifyPerUserOAuthConnection creates a temporary MCP connection using the +// provided access token to verify the server is reachable and discover available +// tools. The connection is closed after verification. This is used during +// per-user OAuth client setup when the admin does a test login to validate the +// OAuth configuration before saving the MCP client. +// +// Parameters: +// - config: MCP client configuration (connection URL, name, etc.) +// - accessToken: temporary OAuth access token from the admin's test login +// +// Returns: +// - map[string]schemas.ChatTool: discovered tools keyed by prefixed name +// - map[string]string: tool name mapping (sanitized β†’ original MCP name) +// - error: any error during verification +func (m *MCPManager) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, nil, fmt.Errorf("connection URL is required for per-user OAuth verification") + } + + // Create HTTP transport with the admin's temporary Bearer token + headers := map[string]string{ + "Authorization": "Bearer " + accessToken, + } + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, nil, fmt.Errorf("failed to create HTTP transport for verification: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + ctx, cancel := context.WithTimeout(ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Start transport + if err := tempClient.Start(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to start MCP connection for verification: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s-verify", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, nil, fmt.Errorf("failed to initialize MCP connection for verification: %w", err) + } + + // Discover tools + tools, toolNameMapping, err := retrieveExternalTools(ctx, tempClient, config.Name, m.logger) + if err != nil { + return nil, nil, fmt.Errorf("failed to discover tools during verification: %w", err) + } + + m.logger.Info("%s Per-user OAuth verification succeeded for '%s': discovered %d tools", MCPLogPrefix, config.Name, len(tools)) + return tools, toolNameMapping, nil +} + +// SetClientTools updates the tool map and name mapping for an existing client. +// This is used to populate tools discovered during per-user OAuth verification, +// where tool discovery happens separately from client creation. +// +// Parameters: +// - clientID: ID of the client to update +// - tools: discovered tools keyed by prefixed name +// - toolNameMapping: mapping from sanitized tool names to original MCP names +func (m *MCPManager) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + m.mu.Lock() + defer m.mu.Unlock() + + if client, exists := m.clientMap[clientID]; exists { + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + client.ToolNameMapping = toolNameMapping + client.State = schemas.MCPConnectionStateConnected + m.logger.Debug("%s Set %d tools on client '%s'", MCPLogPrefix, len(tools), client.Name) + } +} + // RemoveClient removes an MCP client from the manager. // It handles cleanup for all transport types (HTTP, STDIO, SSE). // @@ -243,13 +356,15 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon ConfigHash: client.ExecutionConfig.ConfigHash, ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing), // Updatable fields - copy from updated config with proper cloning - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - Headers: maps.Clone(updatedConfig.Headers), - ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), - ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), - IsPingAvailable: updatedConfig.IsPingAvailable, - ToolSyncInterval: updatedConfig.ToolSyncInterval, + Name: updatedConfig.Name, + IsCodeModeClient: updatedConfig.IsCodeModeClient, + Headers: maps.Clone(updatedConfig.Headers), + ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), + ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), + AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders), + IsPingAvailable: updatedConfig.IsPingAvailable, + ToolSyncInterval: updatedConfig.ToolSyncInterval, + AllowOnAllVirtualKeys: updatedConfig.AllowOnAllVirtualKeys, } // Atomically replace the config pointer @@ -663,7 +778,11 @@ func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { } // Start health monitoring for the client - monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable, m.logger) + isPingAvailable := true + if config.IsPingAvailable != nil { + isPingAvailable = *config.IsPingAvailable + } + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger) m.healthMonitorManager.StartMonitoring(monitor) // Start tool syncing for the client (skip for internal bifrost client) diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go index e81c984195..fa11e52d0b 100644 --- a/core/mcp/codemode.go +++ b/core/mcp/codemode.go @@ -3,7 +3,6 @@ package mcp import ( - "context" "sync" "time" @@ -31,7 +30,7 @@ type CodeMode interface { // ExecuteTool handles a code mode tool call by name. // Returns the response message and any error that occurred. - ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) + ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) // IsCodeModeTool returns true if the given tool name is a code mode tool. IsCodeModeTool(toolName string) bool diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go index da497d8f6f..d2d9435764 100644 --- a/core/mcp/codemode/starlark/executecode.go +++ b/core/mcp/codemode/starlark/executecode.go @@ -5,7 +5,6 @@ package starlark import ( "context" "fmt" - "net/http" "strings" "time" @@ -13,9 +12,11 @@ import ( "github.com/mark3labs/mcp-go/mcp" codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" + "go.starlark.net/syntax" ) // ExecutionResult represents the result of code execution @@ -52,8 +53,11 @@ type ExecutionEnvironment struct { func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { executeToolCodeProps := schemas.NewOrderedMapFromPairs( schemas.KV("code", map[string]interface{}{ - "type": "string", - "description": "Python code to execute. The code runs in a Starlark interpreter (Python subset). Tool calls are synchronous - no async/await needed. For loops/conditionals, wrap in a function. Use print() for logging. ALWAYS retry if code fails. Example: def main():\n items = server.list_items()\n for item in items:\n print(item)\nresult = main()", + "type": "string", + "description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " + + "Use print() for logging. Assign to 'result' variable to return a value. " + + "Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " + + "Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items", }), ) return schemas.ChatTool{ @@ -61,36 +65,36 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { Function: &schemas.ChatToolFunction{ Name: codemcp.ToolTypeExecuteToolCode, Description: schemas.Ptr( - "Executes Python code inside a sandboxed Starlark interpreter with access to all connected MCP servers' tools. " + - "All connected servers are exposed as global objects named after their configuration keys, and each server " + - "provides functions for every tool available on that server. The canonical usage pattern is: " + - "result = .(param=\"value\"). Both and should be discovered " + - "using listToolFiles and readToolFile. " + - - "IMPORTANT WORKFLOW: Always follow this order β€” first use listToolFiles to see available servers and tools, " + - "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + - "to execute your code. " + + "Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " + + "Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " + + "This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " + + "Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " + + + "STARLARK DIFFERENCES FROM PYTHON β€” READ BEFORE WRITING CODE: " + + "1. NO try/except/finally/raise β€” error handling is not supported, and tool failures cannot be caught inside Starlark. " + + "2. NO classes β€” use dicts and functions. " + + "3. NO imports, direct network access, or direct filesystem access β€” use MCP tools instead. " + + "4. NO is operator β€” use == for comparison. " + + "5. NO f-strings β€” use % formatting: \"Hello %s, count=%d\" % (name, n). " + + "6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE β€” no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " + "SYNTAX NOTES: " + - "β€’ Tool calls are synchronous - NO async/await needed, just call directly: result = server.tool(arg=\"value\") " + + "β€’ Synchronous calls β€” NO async/await: result = server.tool(arg=\"value\") " + "β€’ Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " + "β€’ Access dict values with brackets: result[\"key\"] NOT result.key " + - "β€’ Use print() for logging (not console.log) " + - "β€’ List comprehensions work: [x for x in items if x[\"active\"]] " + - "β€’ To return a value, assign to 'result' variable: result = computed_value " + - "β€’ CRITICAL: for/if/while at top level MUST be inside a function - def main(): ... then result = main() " + - - "RETRY POLICY: ALWAYS retry if a code block fails. Analyze the error, adjust your code, and retry. " + - - "The environment is intentionally minimal: " + - "β€’ No imports needed or supported " + - "β€’ No network APIs (use MCP tools for external interactions) " + - "β€’ No file system access (use MCP tools) " + - "β€’ No classes (use dicts and functions) " + - "β€’ Deterministic execution (no random, no time) " + - - "Long-running operations are interrupted via execution timeout. " + - "This tool is designed specifically for orchestrating MCP tool calls and lightweight computation.", + "β€’ Use print() for logging/debugging " + + "β€’ List comprehensions: [x for x in items if x[\"active\"]] " + + "β€’ String escapes work normally: \"line1\\nline2\" produces a newline " + + "β€’ Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " + + "β€’ chr(10) for newline character, chr(9) for tab " + + "β€’ To return a value, assign to 'result': result = computed_value " + + "β€’ MCP tool calls are timeout-limited; avoid long or infinite loops " + + + "AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " + + "int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " + + + "RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.", ), Parameters: &schemas.ToolFunctionParameters{ @@ -103,7 +107,7 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { } // handleExecuteToolCode handles the executeToolCode tool call. -func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { toolName := "unknown" if toolCall.Function.Name != nil { toolName = *toolCall.Function.Name @@ -197,16 +201,13 @@ func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall s } // executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings. -func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult { +func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult { logs := []string{} s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix) - // Step 1: Convert literal \n escape sequences to actual newlines - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - - // Step 2: Handle empty code - trimmedCode := strings.TrimSpace(codeWithNewlines) + // Step 1: Handle empty code + trimmedCode := strings.TrimSpace(code) if trimmedCode == "" { return ExecutionResult{ Result: nil, @@ -218,7 +219,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } } - // Step 3: Build tool bindings for all connected servers + // Step 2: Build tool bindings for all connected servers availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) serverKeys := make([]string, 0, len(availableToolsPerClient)) predeclared := starlark.StringDict{} @@ -254,9 +255,8 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - parsedToolName := parseToolName(unprefixedToolName) + parsedToolName := getCanonicalToolName(clientName, originalToolName) + compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName) s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName) @@ -298,6 +298,13 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }) structMembers[parsedToolName] = toolFunc + + if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) { + if _, exists := structMembers[compatibilityAlias]; !exists { + structMembers[compatibilityAlias] = toolFunc + s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName) + } + } } // Create a struct for this server @@ -312,7 +319,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix) } - // Step 4: Create Starlark thread with print function and timeout + // Step 3: Create Starlark thread with print function and timeout toolExecutionTimeout := s.getToolExecutionTimeout() timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) defer cancel() @@ -324,11 +331,26 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }, } - // Set up cancellation check + // Set up cancellation check β€” watch the context and cancel the Starlark + // thread so that infinite loops and other long-running scripts are interrupted + // when the execution timeout fires. thread.SetLocal("context", timeoutCtx) + go func() { + <-timeoutCtx.Done() + thread.Cancel(timeoutCtx.Err().Error()) + }() + + // Step 4: Configure Starlark dialect options for a Python-like experience + starlarkOpts := &syntax.FileOptions{ + TopLevelControl: true, // allow if/for/while at top level (not just inside functions) + While: true, // enable while loops + Set: true, // enable set() builtin + GlobalReassign: true, // allow reassignment to top-level names + Recursion: true, // allow recursive functions + } // Step 5: Execute the code - globals, err := starlark.ExecFile(thread, "code.star", trimmedCode, predeclared) + globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared) if err != nil { errorMessage := err.Error() @@ -372,7 +394,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } // callMCPTool calls an MCP tool and returns the result. -func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { +func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { // Get available tools per client availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) @@ -400,29 +422,25 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Strip the client name prefix from tool name before calling MCP server originalToolName := stripClientPrefix(toolName, clientName) - // Get BifrostContext for plugin pipeline - var bifrostCtx *schemas.BifrostContext - var ok bool - if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) + if !ok { + originalRequestID = "" } - originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) - // Generate new request ID for this nested tool call var newRequestID string if s.fetchNewRequestIDFunc != nil { - newRequestID = s.fetchNewRequestIDFunc(bifrostCtx) + newRequestID = s.fetchNewRequestIDFunc(ctx) } else { newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) } // Create new child context - deadline, hasDeadline := bifrostCtx.Deadline() + deadline, hasDeadline := ctx.Deadline() if !hasDeadline { deadline = schemas.NoDeadline } - nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) + nestedCtx := schemas.NewBifrostContext(ctx, deadline) nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) if originalRequestID != "" { nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) @@ -451,13 +469,17 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Check if plugin pipeline is available if s.pluginPipelineProvider == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline provider is nil") } // Get plugin pipeline and run hooks pipeline := s.pluginPipelineProvider() if pipeline == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline is nil") } defer s.releasePluginPipeline(pipeline) @@ -515,14 +537,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName Name: toolNameToCall, Arguments: args, }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers + Header: utils.GetHeadersForToolExecution(nestedCtx, client), } toolExecutionTimeout := s.getToolExecutionTimeout() @@ -604,57 +619,3 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName return nil, fmt.Errorf("plugin post-hooks returned invalid response") } - -// callMCPToolDirect executes an MCP tool call directly without plugin hooks. -func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalToolName, - Arguments: args, - }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers - } - - toolExecutionTimeout := s.getToolExecutionTimeout() - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) - return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) - } - - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) - return nil, fmt.Errorf("%s", errorMsg) - } - - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - finalResult = rawResult - } - - resultStr := formatResultForLog(finalResult) - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - - return finalResult, nil -} diff --git a/core/mcp/codemode/starlark/getdocs.go b/core/mcp/codemode/starlark/getdocs.go index 61e4b1dc86..ea622bf7cf 100644 --- a/core/mcp/codemode/starlark/getdocs.go +++ b/core/mcp/codemode/starlark/getdocs.go @@ -71,8 +71,6 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var matchedTool *schemas.ChatTool serverNameLower := strings.ToLower(serverName) - toolNameLower := strings.ToLower(toolName) - for clientName, tools := range availableToolsPerClient { client := s.clientManager.GetClientByName(clientName) if client == nil { @@ -90,10 +88,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem // Find the specific tool for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { matchedTool = &tools[i] break } @@ -125,9 +120,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var availableTools []string for _, tool := range tools { if tool.Function != nil { - unprefixedToolName := stripClientPrefix(tool.Function.Name, matchedClientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools are:\n", toolName, matchedClientName) @@ -150,7 +143,7 @@ func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isTool // Write comprehensive header sb.WriteString("# ============================================================================\n") if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, tools[0].Function.Name)) + sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name))) } else { sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName)) } @@ -187,9 +180,7 @@ func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isTool } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, originalToolName) description := "" if tool.Function.Description != nil { description = *tool.Function.Description diff --git a/core/mcp/codemode/starlark/listfiles.go b/core/mcp/codemode/starlark/listfiles.go index caff015194..4d6aa73add 100644 --- a/core/mcp/codemode/starlark/listfiles.go +++ b/core/mcp/codemode/starlark/listfiles.go @@ -21,7 +21,8 @@ func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { if bindingLevel == schemas.CodeModeBindingLevelServer { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " + "Each server has a corresponding file (e.g., servers/.pyi) that contains compact Python signatures for all tools in that server. " + - "Use readToolFile to read a specific server file and see all available tools with their signatures. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + @@ -30,7 +31,9 @@ func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { } else { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " + "Each tool has a corresponding file (e.g., servers//.pyi) that contains compact Python signatures for that specific tool. " + - "Use readToolFile to read a specific tool file and see its signature. " + + "The shown in each filename is the exact canonical identifier exposed in executeToolCode. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + @@ -88,12 +91,7 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch // Tool-level: one file per tool for _, tool := range tools { if tool.Function != nil && tool.Function.Name != "" { - // Strip the client prefix from tool name (format: "client-toolname" -> "toolname") - // But replace - with _ for valid Python identifiers - toolName := stripClientPrefix(tool.Function.Name, clientName) - // Replace any remaining hyphens with underscores for Python compatibility - toolName = strings.ReplaceAll(toolName, "-", "_") - // Validate normalized tool name to prevent path traversal + toolName := getCanonicalToolName(clientName, tool.Function.Name) if err := validateNormalizedToolName(toolName); err != nil { s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err) continue @@ -112,10 +110,32 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch } // Build tree structure from file list - responseText := buildVFSTree(files) + responseText := buildListToolFilesResponse(files, bindingLevel) return createToolResponseMessage(toolCall, responseText), nil } +func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string { + tree := buildVFSTree(files) + if tree == "" { + return "" + } + + header := []string{ + "# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode", + } + + if bindingLevel == schemas.CodeModeBindingLevelServer { + header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.") + } else { + header = append(header, + "# Filenames below use the exact canonical tool identifiers available in executeToolCode.", + "# Still call readToolFile before executeToolCode to confirm parameters and return shape.", + ) + } + + return strings.Join(append(header, "", tree), "\n") +} + // VFS tree node structure for building hierarchical file structure type treeNode struct { isDirectory bool diff --git a/core/mcp/codemode/starlark/readfile.go b/core/mcp/codemode/starlark/readfile.go index 5940e0c9fc..41063ad065 100644 --- a/core/mcp/codemode/starlark/readfile.go +++ b/core/mcp/codemode/starlark/readfile.go @@ -21,21 +21,23 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { var fileNameDescription, toolDescription string if bindingLevel == schemas.CodeModeBindingLevelServer { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/.pyi (e.g., 'calculator.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers/.pyi (e.g., 'servers/calculator.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " + "for all tools available on that server. The fileName should be in format servers/.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + + "This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " + "Each tool can be accessed in code via: serverName.tool_name(param=value). " + "If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + "do NOT call this tool again with startLine/endLine - you already have the complete file." } else { - fileNameDescription = "The virtual filename from listToolFiles in format: servers//.pyi (e.g., 'calculator/add.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers//.pyi (e.g., 'servers/calculator/add.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " + "The fileName should be in format servers//.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + - "The tool can be accessed in code via: serverName.tool_name(param=value). " + + "This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " + + "The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " + "If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + @@ -126,13 +128,9 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche if isToolLevel { // Tool-level: filter to specific tool var foundTool *schemas.ChatTool - toolNameLower := strings.ToLower(toolName) for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { foundTool = &tools[i] break } @@ -143,15 +141,12 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche availableTools := make([]string, 0) for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. Available tools in this server are:\n", toolName, clientName) for _, t := range availableTools { - errorMsg += fmt.Sprintf(" - %s/%s.pyi\n", clientName, t) + errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t) } return createToolResponseMessage(toolCall, errorMsg), nil } @@ -171,17 +166,14 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche for name := range availableToolsPerClient { if bindingLevel == schemas.CodeModeBindingLevelServer { - availableFiles = append(availableFiles, fmt.Sprintf("%s.pyi", name)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name)) } else { client := s.clientManager.GetClientByName(name) if client != nil && client.ExecutionConfig.IsCodeModeClient { if tools, ok := availableToolsPerClient[name]; ok { for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, name) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.pyi", name, unprefixedToolName)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name))) } } } @@ -295,12 +287,14 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo // Minimal header if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - toolName := parseToolName(stripClientPrefix(tools[0].Function.Name, clientName)) + toolName := getCanonicalToolName(clientName, tools[0].Function.Name) sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName)) } else { sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName)) } sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName)) + sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n") + sb.WriteString("# Read this file before executeToolCode to confirm parameters and return shape.\n") sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName)) sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n") @@ -309,10 +303,7 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo continue } - // Strip client prefix and replace - with _ for code mode compatibility - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, tool.Function.Name) // Format inline parameters in Python style params := formatPythonParams(tool.Function.Parameters) diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go index 0da1d2ccd9..348655b983 100644 --- a/core/mcp/codemode/starlark/starlark.go +++ b/core/mcp/codemode/starlark/starlark.go @@ -6,7 +6,6 @@ package starlark import ( - "context" "fmt" "sync" "sync/atomic" @@ -111,7 +110,7 @@ func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool { // Returns: // - *schemas.ChatMessage: The tool response message // - error: Any error that occurred during execution -func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { if toolCall.Function.Name == nil { return nil, fmt.Errorf("tool call missing function name") } diff --git a/core/mcp/codemode/starlark/starlark_test.go b/core/mcp/codemode/starlark/starlark_test.go index dba557f88a..a48e77a887 100644 --- a/core/mcp/codemode/starlark/starlark_test.go +++ b/core/mcp/codemode/starlark/starlark_test.go @@ -3,13 +3,42 @@ package starlark import ( + "context" + "strings" "testing" + "time" "github.com/bytedance/sonic" + codemcp "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" + "go.starlark.net/syntax" ) +type testClientManager struct { + clients map[string]*schemas.MCPClientState + tools map[string][]schemas.ChatTool +} + +func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + for clientName, tools := range m.tools { + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name == toolName { + return m.clients[clientName] + } + } + } + return nil +} + +func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return m.clients[clientName] +} + +func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return m.tools +} + func TestStarlarkToGo(t *testing.T) { t.Run("Convert None", func(t *testing.T) { result := starlarkToGo(starlark.None) @@ -151,6 +180,83 @@ func TestGoToStarlark(t *testing.T) { }) } +func TestGetCanonicalToolName(t *testing.T) { + if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" { + t.Fatalf("expected canonical tool name search_repos, got %q", got) + } + + if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" { + t.Fatalf("expected canonical tool name _123add, got %q", got) + } +} + +func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) { + clientName := "github" + originalToolName := "github-SEARCH_REPOS" + + testCases := []string{ + "search_repos", + "SEARCH_REPOS", + } + + for _, toolRef := range testCases { + if !matchesToolReference(toolRef, clientName, originalToolName) { + t.Fatalf("expected %q to match %q", toolRef, originalToolName) + } + } +} + +func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) { + mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{ + BindingLevel: schemas.CodeModeBindingLevelTool, + ToolExecutionTimeout: time.Second, + }, nil) + + clientName := "github" + mode.clientManager = &testClientManager{ + clients: map[string]*schemas.MCPClientState{ + clientName: { + Name: clientName, + ExecutionConfig: &schemas.MCPClientConfig{ + Name: clientName, + IsCodeModeClient: true, + }, + }, + }, + tools: map[string][]schemas.ChatTool{ + clientName: { + { + Function: &schemas.ChatToolFunction{ + Name: "github-SEARCH_REPOS", + }, + }, + }, + }, + } + + msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{ + ID: schemas.Ptr("tool-call-1"), + }) + if err != nil { + t.Fatalf("handleListToolFiles returned error: %v", err) + } + + if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil { + t.Fatal("expected tool response content") + } + + content := *msg.Content.ContentStr + if !strings.Contains(content, "search_repos.pyi") { + t.Fatalf("expected canonical tool file path in response, got:\n%s", content) + } + if strings.Contains(content, "SEARCH_REPOS.pyi") { + t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content) + } + if !strings.Contains(content, "readToolFile before executeToolCode") { + t.Fatalf("expected workflow guidance in response, got:\n%s", content) + } +} + func TestGeneratePythonErrorHints(t *testing.T) { serverKeys := []string{"calculator", "weather"} @@ -161,13 +267,13 @@ func TestGeneratePythonErrorHints(t *testing.T) { } found := false for _, hint := range hints { - if containsAny(hint, "not defined", "undefined") { + if strings.Contains(hint, "Variable 'foo' is not defined.") { found = true break } } if !found { - t.Error("Expected hint about undefined variable") + t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints) } }) @@ -489,3 +595,405 @@ func TestFormatResultForLog(t *testing.T) { } }) } + +// starlarkOpts returns the FileOptions used by the code mode executor. +// Kept in sync with executecode.go to test the same dialect configuration. +func starlarkOpts() *syntax.FileOptions { + return &syntax.FileOptions{ + TopLevelControl: true, + While: true, + Set: true, + GlobalReassign: true, + Recursion: true, + } +} + +// execStarlark is a test helper that executes Starlark code with our dialect options +// and returns the globals and any error. +func execStarlark(code string) (starlark.StringDict, error) { + thread := &starlark.Thread{Name: "test"} + return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil) +} + +func TestStarlarkDialectOptions(t *testing.T) { + t.Run("Top-level for loop", func(t *testing.T) { + code := ` +items = [] +for i in range(3): + items.append(i) +result = items +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level for loop should work, got error: %v", err) + } + resultVal := globals["result"] + list, ok := resultVal.(*starlark.List) + if !ok { + t.Fatalf("Expected list, got %T", resultVal) + } + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("Top-level if statement", func(t *testing.T) { + code := ` +x = 10 +if x > 5: + result = "big" +else: + result = "small" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level if should work, got error: %v", err) + } + if globals["result"] != starlark.String("big") { + t.Errorf("Expected 'big', got %v", globals["result"]) + } + }) + + t.Run("Top-level while loop", func(t *testing.T) { + code := ` +count = 0 +while count < 5: + count += 1 +result = count +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level while loop should work, got error: %v", err) + } + resultVal := globals["result"] + if resultVal.String() != "5" { + t.Errorf("Expected 5, got %v", resultVal) + } + }) + + t.Run("While loop inside function", func(t *testing.T) { + code := ` +def countdown(n): + items = [] + while n > 0: + items.append(n) + n -= 1 + return items +result = countdown(3) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("While in function should work, got error: %v", err) + } + list := globals["result"].(*starlark.List) + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("set() builtin", func(t *testing.T) { + code := ` +s = set([1, 2, 3, 2, 1]) +result = len(s) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("set() should work, got error: %v", err) + } + if globals["result"].String() != "3" { + t.Errorf("Expected 3 unique items, got %v", globals["result"]) + } + }) + + t.Run("Global variable reassignment", func(t *testing.T) { + code := ` +x = 1 +x = x + 1 +x = x * 3 +result = x +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Global reassignment should work, got error: %v", err) + } + if globals["result"].String() != "6" { + t.Errorf("Expected 6, got %v", globals["result"]) + } + }) + + t.Run("Recursive function", func(t *testing.T) { + code := ` +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) +result = factorial(5) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Recursion should work, got error: %v", err) + } + if globals["result"].String() != "120" { + t.Errorf("Expected 120, got %v", globals["result"]) + } + }) +} + +func TestStarlarkStringEscapePreservation(t *testing.T) { + t.Run("Backslash-n in string literal preserved", func(t *testing.T) { + // Simulate what happens after JSON deserialization: + // Model writes: {"code": "msg = \"hello\\nworld\""} + // sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n) + // Starlark should interpret \n as newline escape inside the string + code := "msg = \"hello\\nworld\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with \\n escape should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Multiple escape sequences in strings", func(t *testing.T) { + code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with multiple escapes should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "col1\tcol2\nrow1\trow2" { + t.Errorf("Expected tab/newline escapes, got %q", resultStr) + } + }) + + t.Run("Newline join pattern", func(t *testing.T) { + // This is the exact pattern that failed 7 times in benchmarks + code := ` +def main(): + lines = ["line1", "line2", "line3"] + content = "\n".join(lines) + return content +result = main() +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Newline join pattern should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected joined lines, got %q", resultStr) + } + }) + + t.Run("chr() for newline", func(t *testing.T) { + code := ` +nl = chr(10) +result = "hello" + nl + "world" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("chr(10) should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Triple-quoted strings", func(t *testing.T) { + code := "result = \"\"\"line1\nline2\nline3\"\"\"" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Triple-quoted string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected multiline string, got %q", resultStr) + } + }) + + t.Run("Raw string preserves backslash", func(t *testing.T) { + code := "result = r\"hello\\nworld\"" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Raw string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + // Raw string: \n stays as two characters \ and n + if resultStr != "hello\\nworld" { + t.Errorf("Expected literal backslash-n, got %q", resultStr) + } + }) + + t.Run("JSON deserialization then Starlark execution", func(t *testing.T) { + // End-to-end: simulate the exact flow from model JSON β†’ sonic.Unmarshal β†’ Starlark + jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}` + + var arguments map[string]interface{} + err := sonic.Unmarshal([]byte(jsonArgs), &arguments) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + code := arguments["code"].(string) + + globals, starlarkErr := execStarlark(code) + if starlarkErr != nil { + t.Fatalf("Starlark execution failed: %v", starlarkErr) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "a\nb\nc" { + t.Errorf("Expected 'abc', got %q", resultStr) + } + }) +} + +func TestStarlarkUnsupportedFeatures(t *testing.T) { + t.Run("try/except rejected", func(t *testing.T) { + code := ` +def main(): + try: + x = 1 + except: + x = 0 +result = main() +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("try/except should be rejected by Starlark") + } + if !strings.Contains(err.Error(), "got try") { + t.Errorf("Expected 'got try' in error, got: %v", err) + } + }) + + t.Run("raise rejected", func(t *testing.T) { + code := `raise ValueError("test")` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("raise should be rejected by Starlark") + } + }) + + t.Run("class rejected", func(t *testing.T) { + code := ` +class Foo: + pass +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("class should be rejected by Starlark") + } + }) + + t.Run("import rejected", func(t *testing.T) { + code := `import json` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("import should be rejected by Starlark") + } + }) +} + +func TestGeneratePythonErrorHintsNewCases(t *testing.T) { + serverKeys := []string{"Github", "SqLite"} + + t.Run("try/except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for try/except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about try/except not being supported, got: %v", hints) + } + }) + + t.Run("except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("finally hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for finally error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("raise hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for raise error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("Undefined variable includes scope hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for undefined variable") + } + foundVar := false + foundScope := false + for _, hint := range hints { + if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") { + foundVar = true + } + if containsAny(hint, "fresh scope", "persist") { + foundScope = true + } + } + if !foundVar { + t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints) + } + if !foundScope { + t.Errorf("Expected scope persistence hint, got: %v", hints) + } + }) +} diff --git a/core/mcp/codemode/starlark/utils.go b/core/mcp/codemode/starlark/utils.go index 5b7c9ab920..aea6e732d3 100644 --- a/core/mcp/codemode/starlark/utils.go +++ b/core/mcp/codemode/starlark/utils.go @@ -191,11 +191,25 @@ func formatResultForLog(result interface{}) string { func generatePythonErrorHints(errorMessage string, serverKeys []string) []string { hints := []string{} - if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") { - re := regexp.MustCompile(`(\w+).*(?:undefined|not defined)`) - if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { - undefinedVar := match[1] + if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") || + strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") { + hints = append(hints, "Starlark does NOT support try/except/finally/raise β€” there is no exception handling.") + hints = append(hints, "Instead, check return values for errors:") + hints = append(hints, " result = server.tool(param=\"value\")") + hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):") + hints = append(hints, " print(\"Error:\", result)") + } else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") { + var undefinedVar string + if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } + if undefinedVar != "" { hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar)) + hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope β€” no variables persist between calls.") if len(serverKeys) > 0 { hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")") @@ -298,7 +312,7 @@ func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, re } } -// parseToolName parses the tool name to be JavaScript-compatible. +// parseToolName normalizes a raw tool name into a Starlark-compatible identifier. func parseToolName(toolName string) string { if toolName == "" { return "" @@ -349,6 +363,61 @@ func parseToolName(toolName string) string { return parsed } +// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark. +func getCanonicalToolName(clientName, originalToolName string) string { + return parseToolName(stripClientPrefix(originalToolName, clientName)) +} + +// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name. +// This is used as a compatibility alias when the raw name is still a valid Starlark identifier. +func getCompatibilityToolAlias(clientName, originalToolName string) string { + return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_") +} + +// matchesToolReference reports whether the requested tool name matches any supported identifier form. +// We accept the canonical callable name plus legacy display forms for backward compatibility. +func matchesToolReference(requestedToolName, clientName, originalToolName string) bool { + requested := strings.ToLower(requestedToolName) + if requested == "" { + return false + } + + candidates := []string{ + getCanonicalToolName(clientName, originalToolName), + getCompatibilityToolAlias(clientName, originalToolName), + stripClientPrefix(originalToolName, clientName), + } + + for _, candidate := range candidates { + if candidate != "" && requested == strings.ToLower(candidate) { + return true + } + } + + return false +} + +// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code. +func isValidStarlarkIdentifier(name string) bool { + if name == "" { + return false + } + + runes := []rune(name) + first := runes[0] + if !unicode.IsLetter(first) && first != '_' && first != '$' { + return false + } + + for _, r := range runes[1:] { + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' && r != '$' { + return false + } + } + + return true +} + // validateNormalizedToolName validates a normalized tool name to prevent path traversal. func validateNormalizedToolName(normalizedName string) error { if normalizedName == "" { diff --git a/core/mcp/healthmonitor.go b/core/mcp/healthmonitor.go index aa6595fe7a..85769afdcb 100644 --- a/core/mcp/healthmonitor.go +++ b/core/mcp/healthmonitor.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" "github.com/maximhq/bifrost/core/schemas" ) @@ -140,9 +141,14 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } chm.mu.Unlock() - // Get the client connection + // Get the client connection β€” capture Conn while holding the lock so we + // don't race with removeClientUnsafe zeroing it under the write lock. chm.manager.mu.RLock() clientState, exists := chm.manager.clientMap[chm.clientID] + var conn *client.Client + if exists && clientState != nil { + conn = clientState.Conn + } chm.manager.mu.RUnlock() if !exists { @@ -151,7 +157,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } var err error - if clientState.Conn == nil { + if conn == nil { // No active connection β€” treat as a health check failure err = fmt.Errorf("no active connection") } else { @@ -160,7 +166,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { defer cancel() if chm.isPingAvailable { - err = clientState.Conn.Ping(ctx) + err = conn.Ping(ctx) } else { listRequest := mcp.ListToolsRequest{ PaginatedRequest: mcp.PaginatedRequest{ @@ -169,7 +175,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { }, }, } - _, err = clientState.Conn.ListTools(ctx, listRequest) + _, err = conn.ListTools(ctx, listRequest) } } diff --git a/core/mcp/interface.go b/core/mcp/interface.go index 93617ce511..3069e2692a 100644 --- a/core/mcp/interface.go +++ b/core/mcp/interface.go @@ -14,15 +14,17 @@ import ( type MCPManagerInterface interface { // Tool Operations // AddToolsToRequest parses available MCP tools and adds them to the request - AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest + AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest // GetAvailableTools returns all available MCP tools for the given context - GetAvailableTools(ctx context.Context) []schemas.ChatTool + GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool // ExecuteToolCall executes a single tool call and returns the result ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) - // UpdateToolManagerConfig updates the configuration for the tool manager + // UpdateToolManagerConfig updates the configuration for the tool manager. + // DisableAutoToolInject in the config controls auto injection β€” pass the + // current value whenever only other fields change so it is never silently reset. UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) // Agent Mode Operations @@ -60,6 +62,14 @@ type MCPManagerInterface interface { // ReconnectClient reconnects an MCP client by ID ReconnectClient(id string) error + // VerifyPerUserOAuthConnection creates a temporary MCP connection using a + // test access token to verify connectivity and discover tools. The connection + // is closed after verification. + VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) + + // SetClientTools updates the tool map and name mapping for an existing client. + SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) + // Tool Registration // RegisterTool registers a local tool with the MCP server RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index b86409ef1a..adcf3158ac 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -21,13 +21,6 @@ const ( BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment - - // Context keys for client filtering in requests - // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). - // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. - // Request context filtering takes priority over client config - context can override client exclusions. - MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering - MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) ) // ============================================================================ @@ -110,7 +103,7 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider } } - manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, logger) + manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, oauth2Provider, logger) // Set up CodeMode if provided - inject dependencies after manager is created if codeMode != nil { @@ -149,7 +142,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected } manager.mu.Unlock() - monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, clientConfig.IsPingAvailable, manager.logger) + isPingAvailable := true + if clientConfig.IsPingAvailable != nil { + isPingAvailable = *clientConfig.IsPingAvailable + } + monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger) manager.healthMonitorManager.StartMonitoring(monitor) } }(clientConfig) @@ -160,6 +157,13 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider return manager } +// SetPluginPipeline updates the plugin pipeline provider and release function on the manager's +// ToolsManager and CodeMode. Call this after attaching an externally-created MCPManager to a Bifrost +// instance so that nested tool calls in code mode can run through Bifrost's plugin hooks. +func (manager *MCPManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + manager.toolsManager.SetPluginPipeline(provider, release) +} + // AddToolsToRequest parses available MCP tools from the context and adds them to the request. // It respects context-based filtering for clients and tools, and returns the modified request // with tools attached. @@ -170,11 +174,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider // // Returns: // - *schemas.BifrostRequest: The request with tools added -func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *MCPManager) AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { return m.toolsManager.ParseAndAddToolsToRequest(ctx, req) } -func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *MCPManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { return m.toolsManager.GetAvailableTools(ctx) } diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index 8ba68e65ab..029d3fffd4 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -5,13 +5,17 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" "sync/atomic" "time" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -31,11 +35,15 @@ type PluginPipeline interface { // ToolsManager manages MCP tool execution and agent mode. type ToolsManager struct { - toolExecutionTimeout atomic.Value - maxAgentDepth atomic.Int32 - clientManager ClientManager - logger schemas.Logger - agentModeExecutor *AgentModeExecutor + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + disableAutoToolInject atomic.Bool + clientManager ClientManager + logger schemas.Logger + agentModeExecutor *AgentModeExecutor + + // OAuth2Provider for per-user OAuth token management + oauth2Provider schemas.OAuth2Provider // CodeMode implementation for code execution (Starlark by default) codeMode CodeMode @@ -73,6 +81,7 @@ func NewToolsManager( fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { return NewToolsManagerWithCodeMode( @@ -82,6 +91,7 @@ func NewToolsManager( pluginPipelineProvider, releasePluginPipeline, nil, // Use default code mode (will be set later via SetCodeMode) + oauth2Provider, logger, ) } @@ -106,6 +116,7 @@ func NewToolsManagerWithCodeMode( pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), codeMode CodeMode, + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { if config == nil { @@ -142,11 +153,13 @@ func NewToolsManagerWithCodeMode( codeMode: codeMode, logger: logger, agentModeExecutor: agentModeExecutor, + oauth2Provider: oauth2Provider, } // Initialize atomic values manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + manager.disableAutoToolInject.Store(config.DisableAutoToolInject) manager.logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) return manager @@ -174,8 +187,20 @@ func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies { } } +// SetPluginPipeline updates the plugin pipeline provider and release function +// on both the ToolsManager and its CodeMode implementation. +// This is used when an externally-created MCPManager is attached to a Bifrost instance +// via SetMCPManager, so the CodeMode can route nested tool calls through Bifrost's plugin hooks. +func (m *ToolsManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + m.pluginPipelineProvider = provider + m.releasePluginPipeline = release + if m.codeMode != nil { + m.codeMode.SetDependencies(m.GetCodeModeDependencies()) + } +} + // GetAvailableTools returns the available tools for the given context. -func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *ToolsManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) // Flatten tools from all clients into a single slice, avoiding duplicates var availableTools []schemas.ChatTool @@ -191,14 +216,14 @@ func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool } if client.ExecutionConfig.IsCodeModeClient { includeCodeModeTools = true - } else { - // Add tools from this client, checking for duplicates - for _, tool := range clientTools { - if tool.Function != nil && tool.Function.Name != "" { - if !seenToolNames[tool.Function.Name] { - availableTools = append(availableTools, tool) - seenToolNames[tool.Function.Name] = true - } + } + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" && !seenToolNames[tool.Function.Name] { + seenToolNames[tool.Function.Name] = true + schemas.AppendToContextList(ctx, schemas.BifrostContextKeyMCPAddedTools, tool.Function.Name) + if !client.ExecutionConfig.IsCodeModeClient { + availableTools = append(availableTools, tool) } } } @@ -288,12 +313,22 @@ func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integra // // Returns: // - *schemas.BifrostRequest: Bifrost request with MCP tools added -func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { // MCP is only supported for chat and responses requests if req.ChatRequest == nil && req.ResponsesRequest == nil { return req } + // When auto tool injection is disabled, only inject tools if the request + // has explicit context filters set (e.g. via x-bf-mcp-include-tools header). + if m.disableAutoToolInject.Load() { + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) + includeClients := ctx.Value(schemas.MCPContextKeyIncludeClients) + if includeTools == nil && includeClients == nil { + return req + } + } + availableTools := m.GetAvailableTools(ctx) if len(availableTools) == 0 { @@ -541,9 +576,90 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall Name: originalMCPToolName, Arguments: arguments, }, + Header: utils.GetHeadersForToolExecution(ctx, client), } - if client.ExecutionConfig.Headers != nil { + // Handle per-user OAuth: inject user-specific Authorization header + if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth { + if m.oauth2Provider == nil { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured") + } + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + userID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceUserID).(string) + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + + // Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key + if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { + userID = mcpUserID + } + + // Try identity-based token lookup first (works even without session token) + accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) + if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { + // Had session but token lookup failed with a real error (not just "not found") β€” return error + return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) + } + if err != nil { + // No token found β€” user hasn't authenticated with this MCP server yet. + // In LLM gateway mode with no identity, we can't track who this user is, + // so an OAuth flow would produce an orphaned token. Return a clear error instead. + isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) + if !isMCPGateway && userID == "" && virtualKeyID == "" { + return nil, "", "", fmt.Errorf( + "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", + client.ExecutionConfig.Name, + ) + } + + // Initiate OAuth flow to get a proper authorize URL with session tracking. + if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) + } + redirectURI := buildRedirectURIFromContext(ctx) + if redirectURI == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context") + } + flowInitiation, sessionID, flowErr := m.oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI) + if flowErr != nil { + return nil, "", "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr) + } + return nil, "", "", &schemas.MCPUserOAuthRequiredError{ + MCPClientID: client.ExecutionConfig.ID, + MCPClientName: client.ExecutionConfig.Name, + AuthorizeURL: flowInitiation.AuthorizeURL, + SessionID: sessionID, + Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name), + } + } + + if client.Conn == nil { + // No persistent connection β€” create temporary connection with user's token + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := executeToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger) + if callErr != nil { + if toolCtx.Err() == context.DeadlineExceeded { + return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr) + } + responseText := extractTextFromMCPResponse(toolResponse, toolName) + return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil + } + + // Persistent connection exists β€” use per-call headers + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + headers.Set("Authorization", "Bearer "+accessToken) + callRequest.Header = headers + } else if client.ExecutionConfig.Headers != nil { headers := make(http.Header) for key, value := range client.ExecutionConfig.Headers { headers.Add(key, value.GetValue()) @@ -660,17 +776,99 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) } - // Update CodeMode configuration if present - if m.codeMode != nil && config.CodeModeBindingLevel != "" { + // Update CodeMode configuration β€” propagate whenever either field is set + if m.codeMode != nil && (config.CodeModeBindingLevel != "" || config.ToolExecutionTimeout > 0) { m.codeMode.UpdateConfig(&CodeModeConfig{ BindingLevel: config.CodeModeBindingLevel, ToolExecutionTimeout: config.ToolExecutionTimeout, }) } + m.disableAutoToolInject.Store(config.DisableAutoToolInject) + m.logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) } +// executeToolWithUserToken creates a temporary MCP connection using the user's +// OAuth access token, calls the specified tool, and closes the connection. +// This is used for per_user_oauth clients which have no persistent connection β€” +// each tool call gets its own short-lived connection authenticated with the +// requesting user's token. +// +// Parameters: +// - ctx: context with timeout for the entire operation +// - config: MCP client configuration (connection URL, name) +// - toolName: original MCP tool name to call +// - arguments: tool call arguments +// - accessToken: user's OAuth access token +// - logger: logger instance +// +// Returns: +// - *mcp.CallToolResult: tool execution result +// - error: any error during connection or execution +func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution") + } + + // Create HTTP transport with the user's Bearer token, preserving configured headers + headers := make(map[string]string) + if config.Headers != nil { + for key, value := range config.Headers { + headers[key] = value.GetValue() + } + } + headers["Authorization"] = "Bearer " + accessToken + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + if err := tempClient.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start temporary MCP connection: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s-user", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, fmt.Errorf("failed to initialize temporary MCP connection: %w", err) + } + + // Call the tool + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + } + return tempClient.CallTool(ctx, callRequest) +} + +// buildRedirectURIFromContext extracts the OAuth redirect URI from context. +// The URI is set by the HTTP middleware from the request's host. +func buildRedirectURIFromContext(ctx *schemas.BifrostContext) string { + if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" { + return uri + } + // Fallback β€” should not happen if middleware is configured correctly + return "" +} + // GetCodeModeBindingLevel returns the current code mode binding level. // This method is safe to call concurrently from multiple goroutines. func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { diff --git a/core/mcp/utils.go b/core/mcp/utils.go index d80ec17acc..3359b564a4 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -65,7 +65,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. var includeClients []string // Extract client filtering from request context - if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { + if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil { includeClients = existingIncludeClients } @@ -381,12 +381,12 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b // If ToolsToExecute is specified (not nil), apply filtering if config.ToolsToExecute != nil { // Handle empty array [] - means no tools are allowed - if len(config.ToolsToExecute) == 0 { + if config.ToolsToExecute.IsEmpty() { return true // No tools allowed } // Handle wildcard "*" - if present, all tools are allowed - if slices.Contains(config.ToolsToExecute, "*") { + if config.ToolsToExecute.IsUnrestricted() { return false // All tools allowed } @@ -396,7 +396,7 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the allowed list - return !slices.Contains(config.ToolsToExecute, unprefixedToolName) // Tool not in allowed list + return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list } return true // Tool is skipped (nil is treated as [] - no tools) @@ -413,12 +413,12 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // If ToolsToAutoExecute is specified (not nil), apply filtering if config.ToolsToAutoExecute != nil { // Handle empty array [] - means no tools are auto-executed - if len(config.ToolsToAutoExecute) == 0 { + if config.ToolsToAutoExecute.IsEmpty() { return false // No tools auto-executed } // Handle wildcard "*" - if present, all tools are auto-executed - if slices.Contains(config.ToolsToAutoExecute, "*") { + if config.ToolsToAutoExecute.IsUnrestricted() { return true // All tools auto-executed } @@ -428,7 +428,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { unprefixedToolName := stripClientPrefix(toolName, config.Name) // Check if specific tool is in the auto-execute list - return slices.Contains(config.ToolsToAutoExecute, unprefixedToolName) + return config.ToolsToAutoExecute.Contains(unprefixedToolName) } return false // Tool is not auto-executed (nil is treated as [] - no tools) @@ -439,7 +439,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool { // Context filtering can only NARROW the tools available, NOT expand beyond client configuration. // This is checked AFTER client-level filtering (shouldSkipToolForConfig). func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool { - includeTools := ctx.Value(MCPContextKeyIncludeTools) + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) if includeTools != nil { // Try []string first (preferred type) @@ -754,6 +754,7 @@ func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool { if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { return true } + // Check if message has tool calls regardless of finish_reason. // Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present, // so we cannot rely solely on finish_reason to detect tool calls. diff --git a/core/mcp/utils/utils.go b/core/mcp/utils/utils.go new file mode 100644 index 0000000000..500792a09f --- /dev/null +++ b/core/mcp/utils/utils.go @@ -0,0 +1,49 @@ +package utils + +import ( + "net/http" + + "github.com/maximhq/bifrost/core/schemas" +) + +// GetHeadersForToolExecution sets additional headers for tool execution. +// It returns the headers for the tool execution. +func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header { + if ctx == nil || client == nil || client.ExecutionConfig == nil { + return make(http.Header) + } + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + // Give priority to extra headers in the context + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok { + filteredHeaders := make(http.Header) + for key, values := range extraHeaders { + if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) { + for i, value := range values { + if i == 0 { + filteredHeaders.Set(key, value) + } else { + filteredHeaders.Add(key, value) + } + } + } + } + // Add the filtered headers to the headers + if len(filteredHeaders) > 0 { + for k, values := range filteredHeaders { + for i, v := range values { + if i == 0 { + headers.Set(k, v) + } else { + headers.Add(k, v) + } + } + } + } + } + return headers +} diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go index 98abd9b7e1..ca683f26e0 100644 --- a/core/providers/anthropic/anthropic.go +++ b/core/providers/anthropic/anthropic.go @@ -173,7 +173,7 @@ func extractAnthropicResponsesUsageFromPrefetch(data []byte) *schemas.ResponsesR // Returns the response body or an error if the request fails. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, latency, nil) β€” callers must check the context flag. -func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, requestType schemas.RequestType) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -208,7 +208,7 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, requestClient := provider.client responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) - isCountTokens := meta != nil && meta.RequestType == schemas.CountTokensRequest + isCountTokens := requestType == schemas.CountTokensRequest // CountTokens responses are always tiny β€” skip streaming client so the response // is buffered normally (same approach as OpenAI and Gemini count_tokens handlers). if responseThreshold > 0 && !isCountTokens { @@ -233,20 +233,20 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) - return nil, latency, providerResponseHeaders, parseAnthropicError(resp, meta) + return nil, latency, providerResponseHeaders, parseAnthropicError(resp) } // CountTokens uses buffered response (streaming skipped above) β€” decode directly. if isCountTokens { body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return body, latency, providerResponseHeaders, nil } // Delegate large response detection + normal buffered path to shared utility - body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { return nil, latency, providerResponseHeaders, respErr } @@ -290,10 +290,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseAnthropicError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseAnthropicError(resp) } // Parse Anthropic's response @@ -304,7 +301,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext, } // Create final response - response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered) + response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -355,18 +352,13 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToAnthropicTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use struct directly for JSON marshaling (no beta headers for text completion) - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TextCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), schemas.TextCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -379,9 +371,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -400,9 +389,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := response.ToBifrostTextCompletionResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -444,18 +430,13 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k } AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use struct directly for JSON marshaling - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), schemas.ChatCompletionRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -468,9 +449,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -489,9 +467,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := response.ToBifrostChatResponse(ctx) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -528,8 +503,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont anthropicReq.Stream = schemas.Ptr(true) AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic) return anthropicReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -563,11 +537,6 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) } @@ -587,7 +556,6 @@ func HandleAnthropicChatCompletionStreaming( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -634,9 +602,9 @@ func HandleAnthropicChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -645,7 +613,7 @@ func HandleAnthropicChatCompletionStreaming( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -661,14 +629,10 @@ func HandleAnthropicChatCompletionStreaming( // Start streaming in a goroutine go func() { defer func() { - model := "unknown" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -678,7 +642,6 @@ func HandleAnthropicChatCompletionStreaming( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -732,7 +695,7 @@ func HandleAnthropicChatCompletionStreaming( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -791,7 +754,6 @@ func HandleAnthropicChatCompletionStreaming( } } if event.Message != nil { - // Handle different event types modelName = event.Message.Model } @@ -840,11 +802,8 @@ func HandleAnthropicChatCompletionStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -868,22 +827,14 @@ func HandleAnthropicChatCompletionStreaming( response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream(ctx, structuredOutputToolName, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger) break } if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -910,7 +861,7 @@ func HandleAnthropicChatCompletionStreaming( usage.PromptTokens = usage.PromptTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens usage.TotalTokens = usage.TotalTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, modelName, 0) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -939,16 +890,12 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, nil) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), schemas.ResponsesRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -966,9 +913,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc Model: request.Model, Usage: extractAnthropicResponsesUsageFromPrefetch([]byte(preview)), ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -988,9 +932,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc bifrostResponse := response.ToBifrostResponsesResponse(ctx) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1014,7 +955,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, } // Convert to Anthropic format using the centralized converter - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), true, nil) + jsonBody, err := getRequestBodyForResponses(ctx, request, true, nil) if err != nil { return nil, err } @@ -1047,11 +988,6 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner, nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } @@ -1071,7 +1007,6 @@ func HandleAnthropicResponsesStream( postHookRunner schemas.PostHookRunner, postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse, logger schemas.Logger, - meta *providerUtils.RequestMetadata, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -1120,9 +1055,9 @@ func HandleAnthropicResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1131,7 +1066,7 @@ func HandleAnthropicResponsesStream( // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1147,14 +1082,10 @@ func HandleAnthropicResponsesStream( // Start streaming in a goroutine go func() { defer func() { - model := "" - if meta != nil { - model = meta.Model - } if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1164,7 +1095,6 @@ func HandleAnthropicResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1216,7 +1146,7 @@ func HandleAnthropicResponsesStream( if readErr != io.EOF { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1286,11 +1216,6 @@ func HandleAnthropicResponsesStream( ctx.SetValue(schemas.BifrostContextKeyHasEmittedMessageDelta, true) } if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - } // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { return @@ -1307,12 +1232,9 @@ func HandleAnthropicResponsesStream( Type: schemas.ResponsesStreamResponseType(eventType), SequenceNumber: chunkIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), - RawResponse: eventData, + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), + RawResponse: eventData, }, } lastChunkTime = time.Now() @@ -1326,11 +1248,8 @@ func HandleAnthropicResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: modelName, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { response = postResponseConverter(response) @@ -1384,7 +1303,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if len(request.Requests) == 0 { - return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil) } // Create request @@ -1422,7 +1341,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key jsonData, err := providerUtils.MarshalSorted(anthropicReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } usedLargePayloadBody := setAnthropicRequestBody(ctx, req, jsonData) @@ -1442,12 +1361,12 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } var anthropicResp AnthropicBatchResponse @@ -1456,7 +1375,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - return anthropicResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs using serial pagination across keys. @@ -1472,7 +1391,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Initialize serial pagination helper (Anthropic uses AfterID for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.AfterID, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -1483,10 +1402,6 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -1535,12 +1450,12 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.BatchListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicBatchListResponse @@ -1553,7 +1468,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(anthropicResp.Data)) var lastBatchID string for _, batch := range anthropicResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, false, false, nil, nil)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, false, false, nil, nil)) lastBatchID = batch.ID } @@ -1567,9 +1482,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [ Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -1587,7 +1500,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1628,7 +1541,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1640,7 +1553,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1658,8 +1571,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := anthropicResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := anthropicResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -1674,7 +1586,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1711,7 +1623,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1723,7 +1635,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1746,9 +1658,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys Object: anthropicResp.Type, Status: ToBifrostBatchStatus(anthropicResp.ProcessingStatus), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1791,7 +1701,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1825,7 +1735,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1837,7 +1747,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1879,9 +1789,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1959,7 +1867,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -1973,14 +1881,14 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2012,12 +1920,12 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileUploadRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var anthropicResp AnthropicFileResponse @@ -2028,7 +1936,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - return anthropicResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided keys and aggregates results. @@ -2046,7 +1954,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2057,10 +1965,6 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2106,12 +2010,12 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicFileListResponse @@ -2146,9 +2050,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2167,7 +2069,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2208,7 +2110,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2220,7 +2122,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2238,7 +2140,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - return anthropicResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } return nil, lastErr @@ -2253,7 +2155,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2290,7 +2192,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2307,9 +2209,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2319,7 +2219,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2342,9 +2242,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: anthropicResp.Type == "file_deleted", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2372,7 +2270,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -2404,7 +2302,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileContentRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2416,7 +2314,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2436,9 +2334,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2451,16 +2347,12 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.CountTokensRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, []string{"max_tokens", "temperature"}) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, []string{"max_tokens", "temperature"}) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }) + responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), schemas.CountTokensRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2484,9 +2376,6 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) response.Model = request.Model - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2621,7 +2510,7 @@ func (provider *AnthropicProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { @@ -2636,9 +2525,6 @@ func (provider *AnthropicProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2702,9 +2588,9 @@ func (provider *AnthropicProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2715,7 +2601,6 @@ func (provider *AnthropicProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -2727,11 +2612,7 @@ func (provider *AnthropicProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2742,9 +2623,9 @@ func (provider *AnthropicProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2793,7 +2674,7 @@ func (provider *AnthropicProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/anthropic/batch.go b/core/providers/anthropic/batch.go index 405738330c..ac4b0940c4 100644 --- a/core/providers/anthropic/batch.go +++ b/core/providers/anthropic/batch.go @@ -3,9 +3,7 @@ package anthropic import ( "time" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) // Anthropic Batch API Types @@ -129,7 +127,7 @@ func ToBifrostObjectType(anthropicType string) string { } // ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response. -func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { expiresAt := parseAnthropicTimestamp(r.ExpiresAt) resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, @@ -140,9 +138,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExpiresAt: &expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -170,7 +166,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem } // ToBifrostBatchRetrieveResponse converts Anthropic batch response to Bifrost batch retrieve response. -func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: ToBifrostObjectType(r.Type), @@ -179,9 +175,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch ResultsURL: r.ResultsURL, CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -228,26 +222,6 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch return resp } -// ParseAnthropicError parses Anthropic error responses for batch operations. -func ParseAnthropicError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { - var errorResp AnthropicError - bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) - if errorResp.Error != nil { - if errorResp.Error.Type != "" { - bifrostErr.Error.Type = &errorResp.Error.Type - } - if errorResp.Error.Message != "" { - bifrostErr.Error.Message = errorResp.Error.Message - } - } - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - } - return bifrostErr -} - // ToAnthropicBatchCreateResponse converts a Bifrost batch create response to Anthropic format. func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse { result := &AnthropicBatchResponse{ diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index 93c0e7c1d0..fc5624616b 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -418,12 +418,8 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // Initialize Bifrost response bifrostResponse := &schemas.BifrostChatResponse{ - ID: response.ID, - Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Anthropic, - }, + ID: response.ID, + Model: response.Model, Created: int(time.Now().Unix()), } diff --git a/core/providers/anthropic/chat_test.go b/core/providers/anthropic/chat_test.go index 4d0ea9ac45..0680bd1779 100644 --- a/core/providers/anthropic/chat_test.go +++ b/core/providers/anthropic/chat_test.go @@ -337,6 +337,115 @@ func TestToAnthropicChatRequest_ToolInputKeyOrderPreservation(t *testing.T) { } } +func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) { + thinkingText := "Let me reason step by step about this problem." + textBlock1 := "The answer is 42." + textBlock2 := "Here is why that is the case." + signature := "sig_abc123" + + response := &AnthropicMessageResponse{ + ID: "msg_test123", + Type: "message", + Role: "assistant", + Model: "claude-opus-4-6-20250514", + Content: []AnthropicContentBlock{ + { + Type: AnthropicContentBlockTypeThinking, + Thinking: &thinkingText, + Signature: &signature, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock1, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock2, + }, + }, + StopReason: "end_turn", + Usage: &AnthropicUsage{ + InputTokens: 100, + OutputTokens: 50, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result := response.ToBifrostChatResponse(ctx) + + if result == nil { + t.Fatal("expected non-nil result") + } + + // Content should be a combined string, not blocks + choice := result.Choices[0] + msg := choice.ChatNonStreamResponseChoice.Message + if msg.Content.ContentBlocks != nil { + t.Error("expected ContentBlocks to be nil (combined into string)") + } + if msg.Content.ContentStr == nil { + t.Fatal("expected ContentStr to be non-nil") + } + + // Combined string: thinking first, then text blocks + expected := thinkingText + "\n\n" + textBlock1 + "\n\n" + textBlock2 + if *msg.Content.ContentStr != expected { + t.Errorf("expected combined content:\n%s\ngot:\n%s", expected, *msg.Content.ContentStr) + } + + // Reasoning field should still have thinking text + if msg.ChatAssistantMessage == nil { + t.Fatal("expected ChatAssistantMessage to be non-nil") + } + if msg.ChatAssistantMessage.Reasoning == nil { + t.Fatal("expected Reasoning to be non-nil") + } + + // ReasoningDetails should have: signature-only thinking entry + content blocks boundary + rd := msg.ChatAssistantMessage.ReasoningDetails + if len(rd) < 2 { + t.Fatalf("expected at least 2 reasoning details entries, got %d", len(rd)) + } + + // First entry: thinking with signature, no text (text was cleared) + if rd[0].Type != schemas.BifrostReasoningDetailsTypeText { + t.Errorf("expected first reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type) + } + if rd[0].Signature == nil || *rd[0].Signature != signature { + t.Error("expected signature to be preserved") + } + if rd[0].Text != nil { + t.Error("expected thinking text to be nil (cleared to avoid duplication)") + } + + // Last entry: content blocks boundary + lastRD := rd[len(rd)-1] + if lastRD.Type != schemas.BifrostReasoningDetailsTypeContentBlocks { + t.Errorf("expected last reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeContentBlocks, lastRD.Type) + } + if lastRD.Text == nil { + t.Fatal("expected content blocks metadata to be non-nil") + } + + // var meta []contentBlockMeta + // if err := json.Unmarshal([]byte(*lastRD.Text), &meta); err != nil { + // t.Fatalf("failed to unmarshal block metadata: %v", err) + // } + // if len(meta) != 3 { + // t.Fatalf("expected 3 block metadata entries, got %d", len(meta)) + // } + // if meta[0].T != "thinking" || meta[0].L != len(thinkingText) { + // t.Errorf("block 0: expected thinking/%d, got %s/%d", len(thinkingText), meta[0].T, meta[0].L) + // } + // if meta[1].T != "text" || meta[1].L != len(textBlock1) { + // t.Errorf("block 1: expected text/%d, got %s/%d", len(textBlock1), meta[1].T, meta[1].L) + // } + // if meta[2].T != "text" || meta[2].L != len(textBlock2) { + // t.Errorf("block 2: expected text/%d, got %s/%d", len(textBlock2), meta[2].T, meta[2].L) + // } +} + func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) { // Verify existing behavior: single text block without thinking collapses to string text := "Simple response" diff --git a/core/providers/anthropic/errors.go b/core/providers/anthropic/errors.go index dd1dfaf698..81bbd49d0c 100644 --- a/core/providers/anthropic/errors.go +++ b/core/providers/anthropic/errors.go @@ -54,7 +54,7 @@ func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) } -func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp AnthropicError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Error != nil { @@ -64,10 +64,5 @@ func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMet bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index 3da2f6458b..3815a0244b 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -1,13 +1,14 @@ package anthropic import ( + "strings" "time" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -19,57 +20,51 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide HasMore: schemas.Ptr(response.HasMore), } - // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination - // If there are more results, set next_page_token to last_id so it can be used in the next request + // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination. + // If there are more results, set next_page_token to last_id for the next request. if response.HasMore && response.LastID != nil { bifrostResponse.NextPageToken = *response.LastID } - includedModels := make(map[string]bool) - for _, model := range response.Data { - modelID := model.ID - if !unfiltered && len(allowedModels) > 0 { - allowed := false - for _, allowedModel := range allowedModels { - if schemas.SameBaseModel(model.ID, allowedModel) { - modelID = allowedModel - allowed = true - break - } - } - if !allowed { - continue - } - } - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, modelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.DisplayName), - Created: schemas.Ptr(model.CreatedAt.Unix()), - MaxInputTokens: model.MaxInputTokens, - MaxOutputTokens: model.MaxTokens, - ProviderExtra: model.Capabilities, - }) - includedModels[modelID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) - // Backfill allowed models that were not in the response (skip blacklisted; blacklist wins over allow list) - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { continue } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Created: schemas.Ptr(model.CreatedAt.Unix()), + MaxInputTokens: model.MaxInputTokens, + MaxOutputTokens: model.MaxTokens, + ProviderExtra: model.Capabilities, } + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[resolvedKey] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index 5ec1c17a80..6b00918b74 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1429,13 +1429,13 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp if bifrostResp.Response.ID != nil { streamMessage.ID = *bifrostResp.Response.ID } - // Preserve model from Response if available, otherwise use ExtraFields - if bifrostResp.ExtraFields.ModelRequested != "" { - if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { - streamMessage.Model = bifrostResp.Response.Model - } else { - streamMessage.Model = bifrostResp.ExtraFields.ModelRequested - } + // Prefer Response.Model, then ResolvedModelUsed, then OriginalModelRequested + if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { + streamMessage.Model = bifrostResp.Response.Model + } else if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + streamMessage.Model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if bifrostResp.ExtraFields.OriginalModelRequested != "" { + streamMessage.Model = bifrostResp.ExtraFields.OriginalModelRequested } streamResp.Message = streamMessage } diff --git a/core/providers/anthropic/text.go b/core/providers/anthropic/text.go index 3228ad49f6..39a700499b 100644 --- a/core/providers/anthropic/text.go +++ b/core/providers/anthropic/text.go @@ -103,10 +103,6 @@ func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schema TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, }, Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Anthropic, - }, } } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index f803c337e5..867dcf7e97 100644 --- a/core/providers/anthropic/types.go +++ b/core/providers/anthropic/types.go @@ -1248,7 +1248,7 @@ type AnthropicFileDeleteResponse struct { } // ToBifrostFileUploadResponse converts an Anthropic file response to Bifrost file upload response. -func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Type, @@ -1259,9 +1259,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1277,7 +1275,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas } // ToBifrostFileRetrieveResponse converts an Anthropic file response to Bifrost file retrieve response. -func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { +func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { resp := &schemas.BifrostFileRetrieveResponse{ ID: r.ID, Object: r.Type, @@ -1288,9 +1286,7 @@ func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index a0ff905568..c04453644e 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -136,7 +136,7 @@ func setEffortOnOutputConfig(req *AnthropicMessageRequest, effort string) { req.OutputConfig.Effort = &effort } -func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader in completeRequest/ // setAnthropicRequestBody β€” skip all body building here (matches CheckContextAndGetRequestBody). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -156,7 +156,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi _, model := schemas.ParseModelString(modelStr, schemas.Anthropic) jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", model) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -168,36 +168,36 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi } jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", defaultMaxTokens) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Strip auto-injectable server-side tools to prevent conflicts with API auto-injection jsonBody, err = StripAutoInjectableTools(jsonBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Remove excluded fields for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Convert request to Anthropic format reqBody, convErr := ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Anthropic) if isStreaming { @@ -206,7 +206,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -215,14 +215,14 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Use MergeExtraParamsIntoJSON which preserves key order jsonBody, err = providerUtils.MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Remove excluded fields after merging (using sjson to preserve order) for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else if len(excludeFields) > 0 { @@ -230,7 +230,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 3d1059c1c2..16d0a3c301 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -100,7 +100,7 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -109,11 +109,11 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, Scopes: scopes, }) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -138,16 +138,16 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -206,10 +206,8 @@ func (provider *AzureProvider) completeRequest( jsonData []byte, path string, key schemas.Key, - deployment string, model string, - requestType schemas.RequestType, -) ([]byte, string, time.Duration, map[string]string, *schemas.BifrostError) { +) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -222,7 +220,7 @@ func (provider *AzureProvider) completeRequest( }() var url string - isAnthropicModel := schemas.IsAnthropicModel(deployment) + isAnthropicModel := schemas.IsAnthropicModel(model) // Set any extra headers from network config. // For Anthropic models, exclude anthropic-beta β€” it is merged and filtered explicitly below. @@ -237,7 +235,7 @@ func (provider *AzureProvider) completeRequest( // Get authentication headers authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel) if bifrostErr != nil { - return nil, deployment, 0, nil, bifrostErr + return nil, 0, nil, bifrostErr } // Apply headers to request @@ -247,7 +245,7 @@ func (provider *AzureProvider) completeRequest( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, deployment, 0, nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, 0, nil, providerUtils.NewConfigurationError("endpoint not set") } if isAnthropicModel { @@ -282,7 +280,7 @@ func (provider *AzureProvider) completeRequest( latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp) defer wait() if bifrostErr != nil { - return nil, deployment, latency, nil, bifrostErr + return nil, latency, nil, bifrostErr } // Extract provider response headers before body is copied β€” do this before status check @@ -292,19 +290,20 @@ func (provider *AzureProvider) completeRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, deployment, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, requestType, provider.GetProviderKey(), model) + rawErrBody := append([]byte(nil), resp.Body()...) + return rawErrBody, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { - return nil, deployment, latency, providerResponseHeaders, decodeErr + return nil, latency, providerResponseHeaders, decodeErr } if isLargeResp { respOwned = false - return nil, deployment, latency, providerResponseHeaders, nil + return nil, latency, providerResponseHeaders, nil } - return body, deployment, latency, providerResponseHeaders, nil + return body, latency, providerResponseHeaders, nil } // listModelsByKey performs a list models request for a single key. @@ -312,11 +311,11 @@ func (provider *AzureProvider) completeRequest( func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { // Validate Azure key configuration if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", schemas.Azure) + return nil, providerUtils.NewConfigurationError("azure key config not set") } if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", schemas.Azure) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Get API version @@ -359,12 +358,12 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -379,9 +378,9 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert to Bifrost response - response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, key.BlacklistedModels, request.Unfiltered) + response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil) } response.ExtraFields.Latency = latency.Milliseconds() @@ -415,35 +414,23 @@ func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []sc // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized OpenAI text converter (Azure is OpenAI-compatible) jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAITextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/completions", deployment), + fmt.Sprintf("openai/deployments/%s/completions", request.Model), key, - deployment, request.Model, - schemas.TextCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -457,10 +444,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -474,10 +457,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -498,21 +477,12 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s // It formats the request, sends it to Azure, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -520,11 +490,6 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, return nil, err } - customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -538,7 +503,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, nil, postHookRunner, nil, - customPostResponseConverter, + nil, provider.logger, ) } @@ -547,26 +512,16 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { reqBody, err := anthropic.ToAnthropicChatRequest(ctx, request) if err != nil { return nil, err } if reqBody != nil { - reqBody.Model = deployment // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } @@ -574,27 +529,24 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } else { return openai.ToOpenAIChatRequest(ctx, request), nil } - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { - path = fmt.Sprintf("openai/deployments/%s/chat/completions", deployment) + path = fmt.Sprintf("openai/deployments/%s/chat/completions", request.Model) } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ChatCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -608,10 +560,6 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -622,7 +570,7 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -637,12 +585,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ChatCompletionRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -662,22 +606,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -694,14 +624,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, return nil, err } if reqBody != nil { - reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -719,13 +647,8 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -736,7 +659,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Use shared streaming logic from OpenAI return openai.HandleOpenAIChatCompletionStreaming( @@ -754,7 +677,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, nil, - postResponseConverter, + nil, provider.logger, ) } @@ -764,51 +687,36 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - var jsonData []byte var bifrostErr *schemas.BifrostError - if schemas.IsAnthropicModel(deployment) { - jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), false) + if schemas.IsAnthropicModel(request.Model) { + jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false) } else { jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { reqBody := openai.ToOpenAIResponsesRequest(request) - if reqBody != nil { - reqBody.Model = deployment - } return reqBody, nil - }, - provider.GetProviderKey()) + }) } if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { path = "openai/v1/responses" } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ResponsesRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -822,10 +730,6 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -836,7 +740,7 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -851,12 +755,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ResponsesRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -873,22 +773,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema // ResponsesStream performs a streaming responses request to Azure's API. func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -896,7 +782,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint.GetValue()) - jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), true) + jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true) if bifrostErr != nil { return nil, bifrostErr } @@ -914,13 +800,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -929,11 +810,6 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post } url = fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint.GetValue()) - postRequestConverter := func(req *openai.OpenAIResponsesRequest) *openai.OpenAIResponsesRequest { - req.Model = deployment - return req - } - // Use shared streaming logic from OpenAI return openai.HandleOpenAIResponsesStreaming( ctx, @@ -948,8 +824,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post postHookRunner, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -959,35 +835,23 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAIEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/embeddings", deployment), + fmt.Sprintf("openai/deployments/%s/embeddings", request.Model), key, - deployment, request.Model, - schemas.EmbeddingRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -1001,10 +865,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1019,12 +879,8 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.Provider = provider.GetProviderKey() response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.EmbeddingRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1041,15 +897,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema // Speech is not supported by the Azure provider. func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1057,10 +904,10 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAISpeechRequest( ctx, @@ -1080,9 +927,6 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1094,15 +938,6 @@ func (provider *AzureProvider) Rerank(ctx *schemas.BifrostContext, key schemas.K // SpeechStream handles streaming for speech synthesis with Azure. // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON. func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1113,7 +948,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1153,11 +988,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo reqBody := openai.ToOpenAISpeechRequest(request) if reqBody != nil { reqBody.StreamFormat = schemas.Ptr("sse") - reqBody.Model = deployment // Replace model with deployment } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1181,9 +1014,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(requestErr, fasthttp.ErrTimeout) || errors.Is(requestErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1192,7 +1025,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Create response channel @@ -1204,9 +1037,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1307,11 +1140,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo var bifrostErr schemas.BifrostError if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -1333,12 +1161,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Set extra fields for the response response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -1367,7 +1191,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // a fake "done" response with truncated audio. ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1380,12 +1204,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo finalResponse := schemas.BifrostSpeechStreamResponse{ Type: schemas.SpeechStreamResponseTypeDone, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -1408,21 +1228,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo // Transcription is not supported by the Azure provider. func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAITranscriptionRequest( ctx, @@ -1441,9 +1252,6 @@ func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key sc return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1457,16 +1265,6 @@ func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Returns a BifrostResponse containing the bifrost response or an error if the request fails. func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1474,13 +1272,13 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } response, err := openai.HandleOpenAIImageGenerationRequest( ctx, provider.client, - fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()), + fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()), request, key, provider.networkConfig.ExtraHeaders, @@ -1493,9 +1291,6 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1508,18 +1303,6 @@ func (provider *AzureProvider) ImageGenerationStream( key schemas.Key, request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - // - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1527,17 +1310,10 @@ func (provider *AzureProvider) ImageGenerationStream( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse { - if resp != nil { - resp.ExtraFields.ModelDeployment = deployment - } - return resp - } - - url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1558,7 +1334,7 @@ func (provider *AzureProvider) ImageGenerationStream( postHookRunner, nil, nil, - postResponseConverter, + nil, provider.logger, ) @@ -1566,16 +1342,6 @@ func (provider *AzureProvider) ImageGenerationStream( // ImageEdit performs an image edit request to Azure's API. func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault) @@ -1583,10 +1349,10 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAIImageEditRequest( ctx, provider.client, @@ -1603,24 +1369,11 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } // ImageEditStream performs a streaming image edit request to Azure's API. func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - // Validate api key configs - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil || apiVersion.GetValue() == "" { apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault) @@ -1628,17 +1381,10 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse { - if resp != nil { - resp.ExtraFields.ModelDeployment = deployment - } - return resp - } - - url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) if err != nil { @@ -1659,7 +1405,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post postHookRunner, nil, nil, - postResponseConverter, + nil, provider.logger, ) @@ -1673,30 +1419,19 @@ func (provider *AzureProvider) ImageVariation(ctx *schemas.BifrostContext, key s // VideoGeneration creates a video using Azure's OpenAI-compatible Sora API. // This delegates to the OpenAI handler with Azure-specific URL and authentication. func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, bifrostErr := provider.getModelDeployment(key, request.Model) - if bifrostErr != nil { - return nil, bifrostErr - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL for OpenAI-compatible video generation endpoint url := fmt.Sprintf("%s/openai/v1/videos", endpoint) - requestCopy := *request - requestCopy.Model = deployment response, bifrostErr := openai.HandleOpenAIVideoGenerationRequest( ctx, provider.client, url, - &requestCopy, + request, key, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), @@ -1708,27 +1443,20 @@ func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, nil } // VideoRetrieve retrieves the status of a video from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false) @@ -1754,20 +1482,16 @@ func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key sc // VideoDownload downloads video content from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Create request @@ -1803,13 +1527,12 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -1825,9 +1548,7 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc Content: append([]byte(nil), body...), ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1836,20 +1557,16 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc // VideoDelete deletes a video from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", providerName) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1876,13 +1593,9 @@ func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key sche // VideoList lists videos from Azure's OpenAI-compatible API. func (provider *AzureProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } // Build Azure URL @@ -1912,64 +1625,14 @@ func (provider *AzureProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.K return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey()) } -// validateKeyConfig validates the key configuration. -// It checks if the key config is set, the endpoint is set, and the deployments are set. -// Returns an error if any of the checks fail. -func (provider *AzureProvider) validateKeyConfig(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments == nil { - return providerUtils.NewConfigurationError("deployments not set", provider.GetProviderKey()) - } - - return nil -} - -// validateKeyConfigForFiles validates key config for file/batch APIs, which only -// require a configured Azure endpoint (no per-model deployments needed). -func (provider *AzureProvider) validateKeyConfigForFiles(key schemas.Key) *schemas.BifrostError { - if key.AzureKeyConfig == nil { - return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - return nil -} - -func (provider *AzureProvider) getModelDeployment(key schemas.Key, model string) (string, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return "", providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Deployments != nil { - if deployment, ok := key.AzureKeyConfig.Deployments[model]; ok { - return deployment, nil - } - } - return "", providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", model), provider.GetProviderKey()) -} - // FileUpload uploads a file to Azure OpenAI. func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Get API version @@ -1984,7 +1647,7 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add file field @@ -1994,14 +1657,14 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2038,13 +1701,12 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp openai.OpenAIFileResponse @@ -2055,17 +1717,15 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem return nil, bifrostErr } - return openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided Azure keys and aggregates results. // FileList lists files using serial pagination across keys. // Exhausts all pages from one key before moving to the next. func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2074,7 +1734,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2085,18 +1745,9 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2144,13 +1795,12 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIFileListResponse @@ -2185,9 +1835,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2202,7 +1850,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2210,11 +1858,6 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2257,8 +1900,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2270,7 +1912,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2296,14 +1938,12 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [] // FileDelete deletes a file from Azure OpenAI by trying each key until successful. func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2311,11 +1951,6 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2358,8 +1993,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2375,9 +2009,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2387,7 +2019,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2410,9 +2042,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2432,24 +2062,17 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc // FileContent downloads file content from Azure OpenAI by trying each key until found. func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation") } var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2491,8 +2114,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2504,7 +2126,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2524,9 +2146,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2537,12 +2157,6 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // BatchCreate creates a new batch job on Azure OpenAI. // Azure Batch API uses the same format as OpenAI but with Azure-specific URL patterns. func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -2550,12 +2164,11 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Convert inline requests to JSONL format jsonlData, err := openai.ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" uploadResp, bifrostErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{ - Provider: schemas.Azure, File: jsonlData, Filename: "batch_requests.jsonl", Purpose: "batch", @@ -2569,7 +2182,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil) } // Get API version @@ -2616,7 +2229,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -2629,13 +2242,12 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } var openAIResp openai.OpenAIBatchResponse @@ -2646,25 +2258,24 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs from all provided Azure keys and aggregates results. // BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation") } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2675,18 +2286,9 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2732,13 +2334,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.BatchListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIBatchListResponse @@ -2751,7 +2352,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -2764,9 +2365,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2778,14 +2377,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // BatchRetrieve retrieves a specific batch job from Azure OpenAI by trying each key until found. func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2793,11 +2390,6 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2840,8 +2432,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2853,7 +2444,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2871,8 +2462,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -2881,14 +2471,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // BatchCancel cancels a batch job on Azure OpenAI by trying each key until successful. func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2896,11 +2484,6 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2943,8 +2526,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2956,7 +2538,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2981,9 +2563,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3017,8 +2597,6 @@ func (provider *AzureProvider) BatchDelete(ctx *schemas.BifrostContext, keys []s // BatchResults retrieves batch results from Azure OpenAI by trying each key until successful. // For Azure (like OpenAI), batch results are obtained by downloading the output_file_id. func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3029,7 +2607,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file content (using all keys) @@ -3058,9 +2636,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } diff --git a/core/providers/azure/azure_test.go b/core/providers/azure/azure_test.go index 41b32077e7..9caf87ebc7 100644 --- a/core/providers/azure/azure_test.go +++ b/core/providers/azure/azure_test.go @@ -29,7 +29,7 @@ func TestAzure(t *testing.T) { ChatModel: "gpt-4o-backup", PromptCachingModel: "gpt-4o-backup", VisionModel: "gpt-4o", - ChatAudioModel: "gpt-4o-mini-audio-preview", + ChatAudioModel: "gpt-4o-mini-audio-preview", Fallbacks: []schemas.Fallback{ {Provider: schemas.Azure, Model: "gpt-4o-backup"}, }, @@ -42,42 +42,42 @@ func TestAzure(t *testing.T) { ImageEditModel: "gpt-image-1", VideoGenerationModel: "sora-2", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - Embedding: true, - ListModels: true, - Reasoning: true, - ChatAudio: true, - Transcription: false, // Disabled for azure because of 3 calls/minute quota - TranscriptionStream: false, // Not properly supported yet by Azure - SpeechSynthesis: false, // Disabled for azure because of 3 calls/minute quota - SpeechSynthesisStream: false, // Disabled for azure because of 3 calls/minute quota - StructuredOutputs: true, // Structured outputs with nullable enum support - PromptCaching: true, - ImageGeneration: false, // Skipped for Azure - ImageGenerationStream: false, // Skipped for Azure - ImageEdit: false, // Model not deployed on Azure endpoint - ImageEditStream: false, // Model not deployed on Azure endpoint - ImageVariation: false, // Not supported by Azure - VideoGeneration: false, // disabled for now because of long running operations - VideoDownload: false, - VideoRetrieve: false, - VideoRemix: false, - VideoList: false, - VideoDelete: false, - InterleavedThinking: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + Embedding: true, + ListModels: true, + Reasoning: true, + ChatAudio: true, + Transcription: false, // Disabled for azure because of 3 calls/minute quota + TranscriptionStream: false, // Not properly supported yet by Azure + SpeechSynthesis: false, // Disabled for azure because of 3 calls/minute quota + SpeechSynthesisStream: false, // Disabled for azure because of 3 calls/minute quota + StructuredOutputs: true, // Structured outputs with nullable enum support + PromptCaching: true, + ImageGeneration: false, // Skipped for Azure + ImageGenerationStream: false, // Skipped for Azure + ImageEdit: false, // Model not deployed on Azure endpoint + ImageEditStream: false, // Model not deployed on Azure endpoint + ImageVariation: false, // Not supported by Azure + VideoGeneration: false, // disabled for now because of long running operations + VideoDownload: false, + VideoRetrieve: false, + VideoRemix: false, + VideoList: false, + VideoDelete: false, + InterleavedThinking: true, }, DisableParallelFor: []string{"Transcription"}, // Azure Whisper has 3 calls/minute quota } diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go index 4c7ce174f8..d008b146de 100644 --- a/core/providers/azure/files.go +++ b/core/providers/azure/files.go @@ -24,7 +24,7 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -33,11 +33,11 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R Scopes: scopes, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -68,16 +68,16 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -110,9 +110,7 @@ func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mod StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index d5ff81229a..5daca3836d 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -1,65 +1,13 @@ package azure import ( - "slices" + "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact match - if slices.Contains(slice, value) { - return value - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Return the item from slice (not value) to use the actual name from allowedModels - for _, item := range slice { - if schemas.SameBaseModel(item, value) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and base model matches (ignoring version suffixes). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check exact match first (by alias/key) - if deployment, ok := deployments[modelID]; ok { - return deployment, modelID - } - - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == modelID { - return depValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - for aliasKey, deploymentValue := range deployments { - // Check if modelID's base matches deploymentValue's base - if schemas.SameBaseModel(deploymentValue, modelID) { - return deploymentValue, aliasKey - } - // Also check if modelID's base matches alias's base (for cases where alias is used as deployment) - if schemas.SameBaseModel(aliasKey, modelID) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -68,111 +16,36 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - modelID := model.ID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - // Check considering base model matches (ignoring version suffixes) - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deployments) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ID, modelID, deploymentAlias, matchedAllowedModel) { - continue - } - - modelEntry := schemas.Model{ - ID: string(schemas.Azure) + "/" + modelID, - Created: schemas.Ptr(model.CreatedAt), - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true - } else { - includedModels[modelID] = true - } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Azure, + MatchFns: providerUtils.DefaultMatchFns(), } - - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[alias] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[alias] = true - } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Azure) + "/" + result.ResolvedID, + Created: schemas.Ptr(model.CreatedAt), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Azure) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go index 49d1db8de3..20f216c19e 100644 --- a/core/providers/azure/utils.go +++ b/core/providers/azure/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader β€” skip all body building // (matches CheckContextAndGetRequestBody guard). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -27,24 +27,24 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Replace model with deployment jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Delete fallbacks field jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { @@ -52,10 +52,10 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s request.Model = deployment reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } if isStreaming { @@ -68,7 +68,7 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes, preserving field order jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } } diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index b6fe9504d8..0dc798bc0f 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -26,7 +26,6 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/providers/anthropic" - "github.com/maximhq/bifrost/core/providers/cohere" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -222,7 +221,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, 0, nil, err } } @@ -245,10 +244,10 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -349,7 +348,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros if key.Value.GetValue() != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, 0, nil, err } } @@ -370,10 +369,10 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros } var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } var opErr *net.OpError var dnsErr *net.DNSError @@ -420,15 +419,9 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros // makeStreamingRequest creates a streaming request to Bedrock's API. // It formats the request, sends it to Bedrock, and returns the response. // Returns the response body and an error if the request fails. -func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, "", providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - +func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, *schemas.BifrostError) { // Format the path with proper model identifier for streaming - path, deployment := provider.getModelPath(action, model, key) + path := provider.getModelPath(action, model, key) region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { @@ -438,7 +431,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex // Create HTTP request for streaming req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonData)) if reqErr != nil { - return nil, deployment, providerUtils.NewBifrostOperationError("error creating request", reqErr, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", reqErr) } // Set any extra headers from network config @@ -457,8 +450,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex } else { req.Header.Set("Accept", "application/vnd.amazon.eventstream") // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { - return nil, deployment, err + if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { + return nil, err } } @@ -466,7 +459,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex resp, respErr := provider.client.Do(req) if respErr != nil { if errors.Is(respErr, context.Canceled) { - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Type: schemas.Ptr(schemas.RequestCancelled), @@ -478,35 +471,29 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(respErr, &netErr) && netErr.Timeout() { - return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr) } if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.DeadlineExceeded) { - return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError var dnsErr *net.DNSError if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) { - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderNetworkError, Error: respErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - }, } } - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderDoRequest, Error: respErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - }, } } @@ -517,10 +504,10 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - return nil, deployment, parseBedrockHTTPError(resp.StatusCode, resp.Header, body) + return nil, parseBedrockHTTPError(resp.StatusCode, resp.Header, body) } - return resp, deployment, nil + return resp, nil } // signAWSRequest signs an HTTP request using AWS Signature Version 4. @@ -537,7 +524,6 @@ func signAWSRequest( externalID *schemas.EnvVar, sessionName *schemas.EnvVar, region, service string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Set required headers before signing (only if not already set) if req.Header.Get("Content-Type") == "" { @@ -552,7 +538,7 @@ func signAWSRequest( if req.Body != nil { bodyBytes, err := io.ReadAll(req.Body) if err != nil { - return providerUtils.NewBifrostOperationError("error reading request body", err, providerName) + return providerUtils.NewBifrostOperationError("error reading request body", err) } // Restore the body for subsequent reads req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -594,11 +580,10 @@ func signAWSRequest( ) } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } if roleARN != nil && roleARN.GetValue() != "" { - extID := "" if externalID != nil { extID = externalID.GetValue() @@ -653,12 +638,12 @@ func signAWSRequest( // Get credentials creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } // Sign the request with AWS Signature V4 if err := signer.SignHTTP(ctx, creds, req, bodyHash, service, region, time.Now()); err != nil { - return providerUtils.NewBifrostOperationError("failed to sign request", err, providerName) + return providerUtils.NewBifrostOperationError("failed to sign request", err) } return nil @@ -668,13 +653,7 @@ func signAWSRequest( // It retrieves all foundation models available in Amazon Bedrock for a specific key. func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - config := key.BedrockKeyConfig - region := DefaultBedrockRegion if config.Region != nil && config.Region.GetValue() != "" { region = config.Region.GetValue() @@ -721,7 +700,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, err } } @@ -744,10 +723,10 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Check for timeout first using net.Error before checking net.OpError var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -795,9 +774,9 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Convert to Bifrost response - response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, key.BlacklistedModels, request.Unfiltered) + response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil) } response.ExtraFields.Latency = time.Since(startTime).Milliseconds() @@ -838,24 +817,17 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) body, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -867,29 +839,25 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key // Handle model-specific response conversion var bifrostResponse *schemas.BifrostTextCompletionResponse switch { - case schemas.IsAnthropicModel(deployment): + case schemas.IsAnthropicModel(request.Model): var response BedrockAnthropicTextResponse if err := sonic.Unmarshal(body, &response); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err) } bifrostResponse = response.ToBifrostTextCompletionResponse() - case schemas.IsMistralModel(deployment): + case schemas.IsMistralModel(request.Model): var response BedrockMistralTextResponse if err := sonic.Unmarshal(body, &response); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err) } bifrostResponse = response.ToBifrostTextCompletionResponse() default: - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model)) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -902,7 +870,7 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponse interface{} if err := sonic.Unmarshal(body, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err) } bifrostResponse.ExtraFields.RawResponse = rawResponse } @@ -920,22 +888,17 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -951,9 +914,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -999,14 +962,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1037,15 +995,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1057,18 +1010,14 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex } if err := sonic.Unmarshal(message.Payload, &chunkPayload); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } // Create BifrostStreamChunk response containing the raw model-specific JSON chunk textResponse := &schemas.BifrostTextCompletionResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - Latency: time.Since(startTime).Milliseconds(), + Latency: time.Since(startTime).Milliseconds(), // Pass the raw JSON string from the chunk bytes RawResponse: string(chunkPayload.Bytes), }, @@ -1090,26 +1039,19 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Use centralized Bedrock converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockChatCompletionRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Format the path with proper model identifier - path, deployment := provider.getModelPath("converse", request.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -1126,13 +1068,13 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Parse the response using the new Bedrock type if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert using the new response converter bifrostResponse, err := bedrockResponse.ToBifrostChatResponse(ctx, request.Model) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Override finish reason for structured output @@ -1146,10 +1088,6 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1176,21 +1114,17 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } - - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockChatCompletionRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1201,14 +1135,13 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds) - // Start streaming in a goroutine go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1264,7 +1197,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex break } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err) + provider.logger.Warn("Error decoding EventStream message: %v", err) // Transport-level errors (stale/closed connection, unexpected EOF) are retryable. // Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry. if isStreamTransportError(err) { @@ -1274,14 +1207,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1297,7 +1225,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } } errMsg := string(message.Payload) - err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) + err := fmt.Errorf("stream %s: %s", excType, errMsg) // Retryable AWS exceptions must not set IsBifrostError:true β€” that would // bypass the retry gate in executeRequestWithRetries. Instead emit // IsBifrostError:false with the equivalent HTTP status code so the existing @@ -1309,14 +1237,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: err.Error(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1326,7 +1249,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1405,12 +1328,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } chunkIndex++ @@ -1427,11 +1346,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response, bifrostErr, _ := streamEvent.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1440,12 +1354,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response.ID = id response.Model = request.Model response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1464,8 +1374,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } // Send final response - response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) - response.ExtraFields.ModelDeployment = deployment + response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, request.Model, 0) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData) @@ -1486,26 +1395,19 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Use centralized Bedrock converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Format the path with proper model identifier - path, deployment := provider.getModelPath("converse", request.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -1522,22 +1424,18 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche // Parse the response using the new Bedrock type if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert using the new response converter bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(ctx) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = deployment + bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1565,20 +1463,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po return nil, err } - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1594,9 +1489,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1617,7 +1512,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Create stream state for stateful conversions streamState := acquireBedrockResponsesStreamState() - streamState.Model = &deployment + streamState.Model = &request.Model defer releaseBedrockResponsesStreamState(streamState) // Check for structured output mode - if set, we need to intercept tool calls @@ -1633,7 +1528,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po lastChunkTime := startTime decoder := eventstream.NewDecoder() payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer - for { // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { @@ -1651,12 +1545,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) for i, finalResponse := range finalResponses { finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1679,7 +1569,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po break } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err) + provider.logger.Warn("Error decoding EventStream message: %v", err) // Transport-level errors (stale/closed connection, unexpected EOF) are retryable. // Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry. if isStreamTransportError(err) { @@ -1689,14 +1579,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1712,7 +1597,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po } } errMsg := string(message.Payload) - err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) + err := fmt.Errorf("stream %s: %s", excType, errMsg) // Retryable AWS exceptions must not set IsBifrostError:true β€” that would // bypass the retry gate in executeRequestWithRetries. Instead emit // IsBifrostError:false with the equivalent HTTP status code so the existing @@ -1724,14 +1609,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Error: &schemas.ErrorField{ Message: err.Error(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1741,7 +1621,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1797,12 +1677,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po SequenceNumber: chunkIndex, Delta: &content, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } chunkIndex++ @@ -1819,11 +1695,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1831,12 +1702,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po for _, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1862,15 +1729,10 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Determine model type modelType, err := DetermineEmbeddingModelType(request.Model) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } // Convert request and execute based on model type @@ -1879,7 +1741,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string var jsonData []byte switch modelType { @@ -1889,12 +1750,11 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTitanEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) case "cohere": @@ -1903,16 +1763,15 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockCohereEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) default: - return nil, providerUtils.NewConfigurationError("unsupported embedding model type", providerName) + return nil, providerUtils.NewConfigurationError("unsupported embedding model type") } if providerResponseHeaders != nil { @@ -1921,32 +1780,40 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche if bifrostError != nil { return nil, providerUtils.EnrichError(ctx, bifrostError, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - // Parse response based on model type var bifrostResponse *schemas.BifrostEmbeddingResponse switch modelType { case "titan": var titanResp BedrockTitanEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = titanResp.ToBifrostEmbeddingResponse() bifrostResponse.Model = request.Model case "cohere": - var cohereResp cohere.CohereEmbeddingResponse + var cohereResp BedrockCohereEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + converted, convErr := cohereResp.ToBifrostEmbeddingResponse() + if convErr != nil { + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse = cohereResp.ToBifrostEmbeddingResponse() + bifrostResponse = converted bifrostResponse.Model = request.Model + // For embeddings_by_type responses preserve the raw Bedrock payload so the + // invoke-endpoint converter can return all encoding variants verbatim, since + // the internal BifrostEmbeddingResponse only has float32 and string fields. + if cohereResp.ResponseType == "embeddings_by_type" { + var rawResponseData interface{} + if err := sonic.Unmarshal(rawResponse, &rawResponseData); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponseData + } + } } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1972,26 +1839,16 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - - deployment := strings.TrimSpace(resolveBedrockDeployment(request.Model, key)) - if deployment == "" { - return nil, providerUtils.NewConfigurationError("bedrock rerank model is empty", providerName) - } - if !strings.HasPrefix(deployment, "arn:") { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", deployment), providerName) + if !strings.HasPrefix(request.Model, "arn:") { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", request.Model)) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToBedrockRerankRequest(request, deployment) + return ToBedrockRerankRequest(request, request.Model) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2015,10 +1872,6 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas bifrostResponse := response.ToBifrostRerankResponse(request.Documents, returnDocuments) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2048,37 +1901,34 @@ func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext } // ImageGeneration generates images using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2, and Stability AI models. // Returns a BifrostImageGenerationResponse containing the generated images and any error that occurred. func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var rawResponse []byte var jsonData []byte var bifrostError *schemas.BifrostError var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string + + path = provider.getModelPath("invoke", request.Model, key) jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageGenerationRequest(request) + } return ToBedrockImageGenerationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2091,19 +1941,15 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke var bifrostResponse *schemas.BifrostImageGenerationResponse var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2128,33 +1974,36 @@ func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostConte } // ImageEdit performs image editing using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. -// Supports three edit types: INPAINTING, OUTPAINTING, and BACKGROUND_REMOVAL. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2 (three edit types: +// INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL), and Stability AI edit models (inpaint, outpaint, +// recolor, search-replace, erase-object, remove-bg, control-sketch, control-structure, style-guide, +// style-transfer, upscale-creative, upscale-conservative, upscale-fast). // Returns a BifrostImageGenerationResponse containing the edited images and any error that occurred. func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageEditRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError + // Stability AI routing and task-type inference use the actual model ID. + path := provider.getModelPath("invoke", request.Model, key) + jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageEditRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageEditRequest(request, request.Model) + } + return ToBedrockImageEditRequest(request) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2166,20 +2015,16 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche // Parse response (reuse BedrockImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2211,11 +2056,6 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError @@ -2224,14 +2064,13 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageVariationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2243,20 +2082,16 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key // Parse response (reuse BedrockImageGenerationResponse and ToBifrostImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageVariationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2315,13 +2150,6 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is is missing in file upload request") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2343,7 +2171,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch if s3Bucket == "" { provider.logger.Error("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)") - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } // Parse bucket name and optional prefix from s3Bucket (could be "bucket-name" or "s3://bucket-name/prefix/") @@ -2377,14 +2205,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch httpReq, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(request.File)) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } httpReq.Header.Set("Content-Type", "application/octet-stream") httpReq.ContentLength = int64(len(request.File)) // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("error signing request: %s", err.Error.Message) return nil, err } @@ -2404,14 +2232,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) provider.logger.Error("s3 upload failed: %d", resp.StatusCode) - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Return S3 URI as the file ID @@ -2428,9 +2256,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch StorageBackend: schemas.FileStorageS3, StorageURI: s3URI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2443,8 +2269,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2466,7 +2290,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc } if s3Bucket == "" { - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } bucketName, bucketPrefix := parseS3URI(s3Bucket) @@ -2477,7 +2301,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2488,10 +2312,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2518,14 +2338,11 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request for S3 - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); bifrostErr != nil { return nil, bifrostErr } @@ -2544,23 +2361,23 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Parse S3 ListObjectsV2 XML response var listResp S3ListObjectsResponse if err := parseS3ListResponse(body, &listResp); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err) } // Convert files to Bifrost format @@ -2592,9 +2409,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2610,25 +2425,18 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2640,12 +2448,12 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodHead, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2665,13 +2473,13 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, nil, nil) continue } @@ -2701,9 +2509,7 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys StorageBackend: schemas.FileStorageS3, StorageURI: request.FileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2717,25 +2523,18 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2747,12 +2546,12 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2772,7 +2571,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } @@ -2780,7 +2579,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } @@ -2791,9 +2590,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2807,25 +2604,18 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2837,12 +2627,12 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2862,21 +2652,21 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err) continue } @@ -2890,9 +2680,7 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ Content: body, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2907,13 +2695,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is not provided") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Require RoleArn in extra params roleArn := "" // First we will honor the role_arn coming from the client side if present @@ -2924,14 +2705,14 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // If its empty then we will honor the role_arn from the key config if roleArn == "" { - if key.BedrockKeyConfig.ARN != nil { - roleArn = key.BedrockKeyConfig.ARN.GetValue() + if key.BedrockKeyConfig.RoleARN != nil { + roleArn = key.BedrockKeyConfig.RoleARN.GetValue() } } // And if still we don't get role ARN if roleArn == "" { provider.logger.Error("role_arn is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil) } // Get output S3 URI from extra params outputS3Uri := "" @@ -2942,24 +2723,12 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } if outputS3Uri == "" { provider.logger.Error("output_s3_uri is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil) } if request.Model == nil { provider.logger.Error("model is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil, providerName) - } - - // Get model ID - - var modelID *string - if key.BedrockKeyConfig.Deployments != nil && request.Model != nil { - if deployment, ok := key.BedrockKeyConfig.Deployments[*request.Model]; ok { - modelID = schemas.Ptr(deployment) - } - } - if modelID == nil { - modelID = request.Model + return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil) } // Generate job name @@ -2987,9 +2756,9 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // Convert inline requests to Bedrock JSONL format - jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, modelID) + jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, request.Model) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Generate S3 key for the input file @@ -3009,7 +2778,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc bucket, s3Key, jsonlData, - providerName, ); bifrostErr != nil { return nil, bifrostErr } @@ -3020,13 +2788,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc // Validate that we have an input file ID (either provided or uploaded) if inputFileID == "" { provider.logger.Error("either input_file_id (S3 URI) or requests array is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil) } // Build request bedrockReq := &BedrockBatchJobRequest{ JobName: jobName, - ModelID: modelID, + ModelID: request.Model, RoleArn: roleArn, InputDataConfig: BedrockInputDataConfig{ S3InputDataConfig: BedrockS3InputDataConfig{ @@ -3051,7 +2819,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(bedrockReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } sendBackRawRequest := provider.sendBackRawRequest @@ -3066,11 +2834,11 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc reqURL := fmt.Sprintf("https://bedrock.%s.amazonaws.com/model-invocation-job", region) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewBuffer(jsonData)) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } @@ -3089,13 +2857,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc }, }, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { @@ -3104,7 +2872,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, sendBackRawRequest, sendBackRawResponse) } // AWS CreateModelInvocationJob only returns jobArn, not status or other details. @@ -3121,9 +2889,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc InputFileID: inputFileID, Status: schemas.BatchStatusValidating, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3136,9 +2902,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc Status: retrieveResp.Status, CreatedAt: retrieveResp.CreatedAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3156,12 +2920,10 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - // Initialize serial pagination helper (Bedrock uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3172,17 +2934,9 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3205,11 +2959,11 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); bifrostErr != nil { return nil, bifrostErr } @@ -3228,13 +2982,13 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { @@ -3243,7 +2997,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s var bedrockResp BedrockBatchJobListResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert batches to Bifrost format @@ -3288,9 +3042,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3330,7 +3082,7 @@ func (provider *BedrockProvider) fetchBatchManifest(ctx *schemas.BifrostContext, } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("failed to sign manifest request: %v", err) return nil } @@ -3368,19 +3120,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3392,12 +3137,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3417,14 +3162,14 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3435,7 +3180,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) continue } @@ -3454,9 +3199,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys Status: ToBifrostBatchStatus(bedrockResp.Status), Metadata: metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3512,19 +3255,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3536,12 +3272,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3561,14 +3297,14 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3591,9 +3327,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: schemas.BatchStatusCancelling, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: totalLatency.Milliseconds(), + Latency: totalLatency.Milliseconds(), }, }, nil } @@ -3603,9 +3337,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: retrieveResp.Status, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3626,8 +3358,6 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output S3 URI prefix (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3638,7 +3368,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil) } outputS3URI := *batchResp.OutputFileID @@ -3680,7 +3410,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys if directErr != nil { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to access batch results at %s: listing failed and direct access failed", outputS3URI), - nil, providerName) + nil) } // Direct download succeeded, parse the content @@ -3689,9 +3419,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } if len(parseErrors) > 0 { @@ -3724,9 +3452,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: allResults, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: totalLatency, + Latency: totalLatency, }, } @@ -3737,26 +3463,14 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return batchResultsResp, nil } -func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) (string, string) { - deployment := resolveBedrockDeployment(model, key) - // Default: use model/deployment directly - path := fmt.Sprintf("%s/%s", deployment, basePath) +func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string { + path := fmt.Sprintf("%s/%s", model, basePath) // If ARN is present, Bedrock expects the ARN-scoped identifier if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.ARN != nil && key.BedrockKeyConfig.ARN.GetValue() != "" { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), deployment)) + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), model)) path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath) } - return path, deployment -} - -func resolveBedrockDeployment(model string, key schemas.Key) string { - deployment := model - if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.Deployments != nil { - if mapped, ok := key.BedrockKeyConfig.Deployments[model]; ok && mapped != "" { - deployment = mapped - } - } - return deployment + return path } func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { @@ -3764,16 +3478,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Convert to Bedrock Converse format using the existing responses converter converseReq, convErr := ToBedrockResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr) } // Wrap in the CountTokens request envelope @@ -3782,11 +3490,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(countTokensReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Format the path with proper model identifier - path, deployment := provider.getModelPath("count-tokens", request.Model, key) + path := provider.getModelPath("count-tokens", request.Model, key) // Send the request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -3797,15 +3505,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc if isCountTokensUnsupported(bifrostErr) { estimated := estimateTokenCount(jsonData) return &schemas.BifrostCountTokensResponse{ - Model: deployment, + Model: request.Model, InputTokens: estimated, TotalTokens: &estimated, Object: "response.input_tokens", ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - ModelRequested: request.Model, - ModelDeployment: deployment, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3828,15 +3532,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc } // Convert to Bifrost format - response := bedrockResponse.ToBifrostCountTokensResponse(deployment) + response := bedrockResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequest } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index 1949051c44..95a3b00982 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -175,53 +175,53 @@ func TestBedrock(t *testing.T) { {Provider: schemas.Bedrock, Model: "claude-4-sonnet"}, {Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"}, }, - EmbeddingModel: "cohere.embed-v4:0", - RerankModel: rerankModelARN, - ReasoningModel: "claude-4.5-sonnet", - PromptCachingModel: "claude-4.5-sonnet", - ImageEditModel: "amazon.nova-canvas-v1:0", - ImageVariationModel: "amazon.nova-canvas-v1:0", + EmbeddingModel: "cohere.embed-v4:0", + RerankModel: rerankModelARN, + ReasoningModel: "claude-4.5-sonnet", + PromptCachingModel: "claude-4.5-sonnet", + ImageEditModel: "amazon.nova-canvas-v1:0", + ImageVariationModel: "amazon.nova-canvas-v1:0", InterleavedThinkingModel: "global.anthropic.claude-opus-4-5-20251101-v1:0", - BatchExtraParams: batchExtraParams, - FileExtraParams: fileExtraParams, + BatchExtraParams: batchExtraParams, + FileExtraParams: fileExtraParams, Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, // Bedrock doesn't support image URL - ImageBase64: true, - MultipleImages: false, // Since one of the image is URL - FileBase64: true, - FileURL: false, // S3 urls supported for nova models - CompleteEnd2End: true, - Embedding: true, - Rerank: rerankModelARN != "", - ListModels: true, - Reasoning: true, - PromptCaching: true, - BatchCreate: true, - BatchList: true, - BatchRetrieve: true, - BatchCancel: true, - BatchResults: true, - FileUpload: true, - FileList: true, - FileRetrieve: true, - FileDelete: true, - FileContent: true, - FileBatchInput: true, - CountTokens: true, - ImageEdit: true, - ImageVariation: true, - StructuredOutputs: true, - InterleavedThinking: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, // Bedrock doesn't support image URL + ImageBase64: true, + MultipleImages: false, // Since one of the image is URL + FileBase64: true, + FileURL: false, // S3 urls supported for nova models + CompleteEnd2End: true, + Embedding: true, + Rerank: rerankModelARN != "", + ListModels: true, + Reasoning: true, + PromptCaching: true, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: true, + FileBatchInput: true, + CountTokens: true, + ImageEdit: true, + ImageVariation: true, + StructuredOutputs: true, + InterleavedThinking: true, }, } @@ -1256,7 +1256,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "tool-use-123", Name: "get_weather", - Input: json.RawMessage(`{"location":"NYC"}`), + Input: json.RawMessage(`{"location":"NYC"}`), }, }, }, @@ -1331,7 +1331,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "tool-use-456", Name: "calculate", - Input: json.RawMessage(`{"expression":"2+2"}`), + Input: json.RawMessage(`{"expression":"2+2"}`), }, }, }, @@ -1860,7 +1860,7 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "call-111", Name: "get_weather", - Input: json.RawMessage(`{"location":"NYC"}`), + Input: json.RawMessage(`{"location":"NYC"}`), }, }, { @@ -2248,7 +2248,7 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name: "JSONObjectResult", toolResultContent: `{"location":"NYC","temperature":72}`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"location": "NYC", "temperature": float64(72)}), + expectedJSON: mustMarshalJSON(map[string]any{"location": "NYC", "temperature": float64(72)}), }, { name: "JSONArrayResult", @@ -2265,37 +2265,37 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name: "JSONPrimitiveNumberResult", toolResultContent: `42`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": float64(42)}), + expectedJSON: mustMarshalJSON(map[string]any{"value": float64(42)}), }, { name: "JSONPrimitiveStringResult", toolResultContent: `"hello world"`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": "hello world"}), + expectedJSON: mustMarshalJSON(map[string]any{"value": "hello world"}), }, { name: "JSONPrimitiveBooleanResult", toolResultContent: `true`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": true}), + expectedJSON: mustMarshalJSON(map[string]any{"value": true}), }, { name: "JSONPrimitiveNullResult", toolResultContent: `null`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": nil}), + expectedJSON: mustMarshalJSON(map[string]any{"value": nil}), }, { name: "EmptyJSONObjectResult", toolResultContent: `{}`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{}), + expectedJSON: mustMarshalJSON(map[string]any{}), }, { name: "EmptyJSONArrayResult", toolResultContent: `[]`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"results": []any{}}), + expectedJSON: mustMarshalJSON(map[string]any{"results": []any{}}), }, } @@ -3626,17 +3626,17 @@ func TestToBedrockInvokeMessagesStreamResponse_NoDuplicateContentBlockStop(t *te { Type: schemas.ResponsesStreamResponseTypeOutputTextDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeContentPartDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeOutputItemDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, } diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 71e7890935..6459df377b 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -247,8 +247,6 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte Usage: usage, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Bedrock, }, } diff --git a/core/providers/bedrock/embedding.go b/core/providers/bedrock/embedding.go index d9981a1e4d..2e2875e9cf 100644 --- a/core/providers/bedrock/embedding.go +++ b/core/providers/bedrock/embedding.go @@ -1,10 +1,10 @@ package bedrock import ( + "encoding/json" "fmt" "strings" - "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/schemas" ) @@ -19,11 +19,6 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) return nil, fmt.Errorf("no input text provided for embedding") } - // Validate dimensions parameter - Titan models do not support it - if bifrostReq.Params != nil && bifrostReq.Params.Dimensions != nil { - return nil, fmt.Errorf("amazon Titan embedding models do not support custom dimensions parameter") - } - titanReq := &BedrockTitanEmbeddingRequest{} // Set input text @@ -36,8 +31,26 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) } titanReq.InputText = embeddingText } + if bifrostReq.Params != nil { - titanReq.ExtraParams = bifrostReq.Params.ExtraParams + titanReq.Dimensions = bifrostReq.Params.Dimensions + if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok { + if b, ok := normalize.(bool); ok { + titanReq.Normalize = &b + } + } + // Forward remaining extra params (excluding normalize which is now a first-class field) + if len(bifrostReq.Params.ExtraParams) > 0 { + extra := make(map[string]interface{}) + for k, v := range bifrostReq.Params.ExtraParams { + if k != "normalize" { + extra[k] = v + } + } + if len(extra) > 0 { + titanReq.ExtraParams = extra + } + } } return titanReq, nil @@ -69,20 +82,81 @@ func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *sch return bifrostResponse } -// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format -// Reuses the Cohere converter since the format is identical -func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*cohere.CohereEmbeddingRequest, error) { +// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body. +func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost embedding request is nil") } + if bifrostReq.Input == nil { + return nil, fmt.Errorf("no input provided for embedding") + } - // Reuse Cohere's converter - the format is identical for Bedrock - cohereReq := cohere.ToCohereEmbeddingRequest(bifrostReq) - if cohereReq == nil { - return nil, fmt.Errorf("failed to convert to Cohere embedding request") + req := &BedrockCohereEmbeddingRequest{} + + // Map texts + if bifrostReq.Input.Text != nil { + req.Texts = []string{*bifrostReq.Input.Text} + } else if len(bifrostReq.Input.Texts) > 0 { + req.Texts = bifrostReq.Input.Texts } - return cohereReq, nil + if bifrostReq.Params != nil { + extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams)) + for k, v := range bifrostReq.Params.ExtraParams { + extra[k] = v + } + + if v, ok := extra["input_type"]; ok { + if s, ok := v.(string); ok { + req.InputType = s + delete(extra, "input_type") + } + } + if v, ok := extra["truncate"]; ok { + if s, ok := v.(string); ok { + req.Truncate = &s + delete(extra, "truncate") + } + } + if v, ok := extra["embedding_types"]; ok { + if ss, ok := v.([]string); ok { + req.EmbeddingTypes = ss + delete(extra, "embedding_types") + } + } + if v, ok := extra["images"]; ok { + if ss, ok := v.([]string); ok { + req.Images = ss + delete(extra, "images") + } + } + if v, ok := extra["inputs"]; ok { + if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok { + req.Inputs = inputs + delete(extra, "inputs") + } + } + if v, ok := extra["max_tokens"]; ok { + switch n := v.(type) { + case int: + req.MaxTokens = &n + delete(extra, "max_tokens") + case float64: + i := int(n) + req.MaxTokens = &i + delete(extra, "max_tokens") + } + } + if bifrostReq.Params.Dimensions != nil { + req.OutputDimension = bifrostReq.Params.Dimensions + } + if len(extra) > 0 { + req.ExtraParams = extra + } + } + + return req, nil } // DetermineEmbeddingModelType determines the embedding model type from the model name @@ -96,3 +170,102 @@ func DetermineEmbeddingModelType(model string) (string, error) { return "", fmt.Errorf("unsupported embedding model: %s", model) } } + +// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format. +// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats" +// (the default, when no embedding_types are requested), and as a typed object when +// response_type is "embeddings_by_type". +func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) { + if r == nil { + return nil, fmt.Errorf("nil Bedrock Cohere embedding response") + } + + bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"} + + switch r.ResponseType { + case "embeddings_by_type": + // Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]} + var typed struct { + Float [][]float32 `json:"float"` + Base64 []string `json:"base64"` + Int8 [][]int8 `json:"int8"` + Uint8 [][]int32 `json:"uint8"` // int32 avoids []byteβ†’base64 JSON issue + Binary [][]int8 `json:"binary"` + Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byteβ†’base64 JSON issue + } + if err := json.Unmarshal(r.Embeddings, &typed); err != nil { + return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err) + } + if typed.Float != nil { + for i, emb := range typed.Float { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + if typed.Base64 != nil { + for i, emb := range typed.Base64 { + e := emb + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e}, + }) + } + } + for i, emb := range typed.Int8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Binary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Uint8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + for i, emb := range typed.Ubinary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + + default: + // Default / "embeddings_floats": raw array form [[...], [...]] + var floats [][]float32 + if err := json.Unmarshal(r.Embeddings, &floats); err != nil { + return nil, fmt.Errorf("error parsing embeddings_floats: %w", err) + } + for i, emb := range floats { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + + return bifrostResponse, nil +} diff --git a/core/providers/bedrock/images.go b/core/providers/bedrock/images.go index 8c9ba9569c..b0ac35dc01 100644 --- a/core/providers/bedrock/images.go +++ b/core/providers/bedrock/images.go @@ -34,6 +34,61 @@ func mapQualityToBedrock(quality *string) *string { } } +// isStabilityAIModel returns true if the model is a Stability AI model (contains "stability.") +func isStabilityAIModel(model string) bool { + return strings.Contains(strings.ToLower(model), "stability.") +} + +// isPromptOnlyImageGenerationModel returns true for image generation models that use a flat +// {"prompt": "..."} payload (no taskType field). Covers Vertex Imagen and similar models. +// Stability AI is excluded here β€” it's handled separately because it also supports image edit. +func isPromptOnlyImageGenerationModel(model string) bool { + m := strings.ToLower(model) + return strings.Contains(m, "image") +} + +// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI +// flat request format used by Bedrock (stability.stable-image-* models). +func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) { + if request == nil { + return nil, fmt.Errorf("request is nil") + } + if request.Input == nil { + return nil, fmt.Errorf("request input is required") + } + + req := &StabilityAIImageGenerationRequest{ + Prompt: request.Input.Prompt, + } + + if request.Params != nil { + if request.Params.AspectRatio != nil { + req.AspectRatio = request.Params.AspectRatio + } + if request.Params.OutputFormat != nil { + req.OutputFormat = request.Params.OutputFormat + } + if request.Params.Seed != nil { + req.Seed = request.Params.Seed + } + if request.Params.NegativePrompt != nil { + req.NegativePrompt = request.Params.NegativePrompt + } + if request.Params.ExtraParams != nil { + // aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set + if req.AspectRatio == nil { + if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok { + delete(request.Params.ExtraParams, "aspect_ratio") + req.AspectRatio = ar + } + } + req.ExtraParams = request.Params.ExtraParams + } + } + + return req, nil +} + // ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) { if request == nil { @@ -41,7 +96,7 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } if request.Input == nil { - return nil, fmt.Errorf("request.Input is required") + return nil, fmt.Errorf("request input is required") } bedrockReq := &BedrockImageGenerationRequest{ @@ -101,6 +156,24 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } +// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to +// the native Bedrock invoke API response format used by Stability AI models. +// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas. +func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) { + if response == nil { + return nil, fmt.Errorf("response is nil") + } + result := &BedrockImageGenerationResponse{} + for _, d := range response.Data { + result.Images = append(result.Images, d.B64JSON) + } + if response.ImageGenerationResponseParameters != nil { + result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons + result.Seeds = response.ImageGenerationResponseParameters.Seeds + } + return result, nil +} + // ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) { if request == nil { @@ -358,6 +431,292 @@ func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGener return config } +// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value +// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type. +func getStabilityAITaskTypeFromParams(t string) string { + switch strings.ToLower(t) { + case "inpainting", "inpaint": + return "inpaint" + case "outpainting", "outpaint": + return "outpaint" + case "background_removal", "remove_background": + return "remove-bg" + case "erase_object": + return "erase-object" + case "upscale_fast": + return "upscale-fast" + case "upscale_creative": + return "upscale-creative" + case "upscale_conservative": + return "upscale-conservative" + case "recolor": + return "recolor" + case "search_replace": + return "search-replace" + case "control_sketch": + return "control-sketch" + case "control_structure": + return "control-structure" + case "style_guide": + return "style-guide" + case "style_transfer": + return "style-transfer" + default: + return "" + } +} + +// getStabilityAIEditTaskType infers the Stability AI edit task from the model name. +// Returns an error if the model name does not match any known pattern. +func getStabilityAIEditTaskType(model string) (string, error) { + m := strings.ToLower(model) + switch { + case strings.Contains(m, "stable-creative-upscale"): + return "upscale-creative", nil + case strings.Contains(m, "stable-conservative-upscale"): + return "upscale-conservative", nil + case strings.Contains(m, "stable-fast-upscale"): + return "upscale-fast", nil + case strings.Contains(m, "stable-image-inpaint"): + return "inpaint", nil + case strings.Contains(m, "stable-outpaint"): + return "outpaint", nil + case strings.Contains(m, "stable-image-search-recolor"): + return "recolor", nil + case strings.Contains(m, "stable-image-search-replace"): + return "search-replace", nil + case strings.Contains(m, "stable-image-erase-object"): + return "erase-object", nil + case strings.Contains(m, "stable-image-remove-background"): + return "remove-bg", nil + case strings.Contains(m, "stable-image-control-sketch"): + return "control-sketch", nil + case strings.Contains(m, "stable-image-control-structure"): + return "control-structure", nil + case strings.Contains(m, "stable-image-style-guide"): + return "style-guide", nil + case strings.Contains(m, "stable-style-transfer"): + return "style-transfer", nil + default: + return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model) + } +} + +// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request +// format used by Bedrock edit models. Only fields valid for the detected task type are populated. +// deployment is the resolved model identifier (after applying any deployment alias mapping); it is +// used for task-type inference so that alias-mapped models route correctly. +func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) { + if request == nil || request.Input == nil { + return nil, fmt.Errorf("request or input is nil") + } + + var taskType string + if request.Params != nil && request.Params.Type != nil { + taskType = getStabilityAITaskTypeFromParams(*request.Params.Type) + } + if taskType == "" { + var err error + taskType, err = getStabilityAIEditTaskType(deployment) + if err != nil { + return nil, err + } + } + + req := &StabilityAIImageEditRequest{} + + // Image sourcing + if taskType == "style-transfer" { + if len(request.Input.Images) != 2 { + return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image") + } + if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 { + return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image") + } + initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image) + req.InitImage = &initB64 + req.StyleImage = &styleB64 + } else { + if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 { + return nil, fmt.Errorf("at least one image is required") + } + imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + req.Image = &imageB64 + } + + // Common fields populated based on task allowlist + prompt := request.Input.Prompt + switch taskType { + case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure", + "style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer": + req.Prompt = &prompt + } + + // Negative prompt + if request.Params != nil && request.Params.NegativePrompt != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.NegativePrompt = request.Params.NegativePrompt + } + } + + // Seed + if request.Params != nil && request.Params.Seed != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.Seed = request.Params.Seed + } + } + + // Mask (from Params.Mask bytes) + if request.Params != nil && len(request.Params.Mask) > 0 { + switch taskType { + case "inpaint", "erase-object": + maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask) + req.Mask = &maskB64 + } + } + + // ExtraParams + if request.Params != nil { + // Typed OutputFormat takes priority over ExtraParams + if request.Params.OutputFormat != nil { + req.OutputFormat = request.Params.OutputFormat + } + + if request.Params.ExtraParams != nil { + ep := make(map[string]interface{}, len(request.Params.ExtraParams)) + for k, v := range request.Params.ExtraParams { + ep[k] = v + } + + // output_format β€” all tasks (fallback if not already set by typed field) + if req.OutputFormat == nil { + if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok { + delete(ep, "output_format") + req.OutputFormat = v + } + } + + // style_preset + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative": + if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok { + delete(ep, "style_preset") + req.StylePreset = v + } + } + + // grow_mask + switch taskType { + case "inpaint", "recolor", "search-replace", "erase-object": + if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok { + delete(ep, "grow_mask") + req.GrowMask = v + } + } + + // outpaint directional fields + if taskType == "outpaint" { + if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok { + delete(ep, "left") + req.Left = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok { + delete(ep, "right") + req.Right = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok { + delete(ep, "up") + req.Up = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok { + delete(ep, "down") + req.Down = v + } + } + + // creativity + switch taskType { + case "upscale-creative", "upscale-conservative", "outpaint": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok { + delete(ep, "creativity") + req.Creativity = v + } + } + + // select_prompt (recolor) + if taskType == "recolor" { + if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok { + delete(ep, "select_prompt") + req.SelectPrompt = v + } + } + + // search_prompt (search-replace) + if taskType == "search-replace" { + if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok { + delete(ep, "search_prompt") + req.SearchPrompt = v + } + } + + // control_strength + switch taskType { + case "control-sketch", "control-structure": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok { + delete(ep, "control_strength") + req.ControlStrength = v + } + } + + // style-guide fields + if taskType == "style-guide" { + if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok { + delete(ep, "aspect_ratio") + req.AspectRatio = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok { + delete(ep, "fidelity") + req.Fidelity = v + } + } + + // style-transfer fields + if taskType == "style-transfer" { + if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok { + delete(ep, "style_strength") + req.StyleStrength = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok { + delete(ep, "composition_fidelity") + req.CompositionFidelity = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok { + delete(ep, "change_strength") + req.ChangeStrength = v + } + } + + req.ExtraParams = ep + } + } + + // Validate required per-task fields + if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") { + return nil, fmt.Errorf("select_prompt is required for stability ai recolor task") + } + if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") { + return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task") + } + + return req, nil +} + // ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse { if response == nil { @@ -366,6 +725,13 @@ func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) bifrostResponse := &schemas.BifrostImageGenerationResponse{} + if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 { + bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{ + FinishReasons: append([]*string(nil), response.FinishReasons...), + Seeds: append([]int(nil), response.Seeds...), + } + } + for index, image := range response.Images { bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{ B64JSON: image, diff --git a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index c21520edc4..600a44d0e2 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -2,8 +2,10 @@ package bedrock import ( "bytes" + "encoding/base64" "encoding/json" "fmt" + "net/url" "strings" "github.com/bytedance/sonic" @@ -44,6 +46,17 @@ var bedrockInvokeRequestKnownFields = map[string]bool{ "message": true, "chat_history": true, // AI21 "n": true, "frequency_penalty": true, "presence_penalty": true, + // Bedrock image gen / edit / variation (Titan/Nova Canvas) + "taskType": true, "textToImageParams": true, "imageVariationParams": true, + "inPaintingParams": true, "outPaintingParams": true, "backgroundRemovalParams": true, + "imageGenerationConfig": true, + // Stability AI image + "image": true, "mask": true, "negative_prompt": true, + "aspect_ratio": true, "output_format": true, "seed": true, + // Embeddings + "inputText": true, "texts": true, "input_type": true, + "normalize": true, "dimensions": true, + "embedding_types": true, "output_dimension": true, "inputs": true, // Internal "stream": true, "extra_params": true, } @@ -125,17 +138,74 @@ func (r *BedrockInvokeRequest) UnmarshalJSON(data []byte) error { return nil } -// DetectInvokeRequestType determines the request type from raw JSON body -// without full deserialization, keeping detection logic colocated with IsMessagesRequest. -func DetectInvokeRequestType(body []byte) schemas.RequestType { - node, _ := sonic.Get(body, "messages") - if node.Exists() { - raw, err := node.Raw() - if err == nil && raw != "null" && raw != "[]" { +// DetectInvokeRequestType determines the request type from raw JSON body and model ID +// without full deserialization, keeping detection logic colocated with conversion methods. +func DetectInvokeRequestType(body []byte, modelID string) schemas.RequestType { + // Messages β†’ chat/responses path + if node, _ := sonic.Get(body, "messages"); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { return schemas.ResponsesRequest } + } + + // Titan uses "inputText" for both embeddings and text generation. + // Use the model ID to disambiguate: embedding models contain "embed". + if node, _ := sonic.Get(body, "inputText"); node.Exists() { + if strings.Contains(strings.ToLower(modelID), "embed") { + return schemas.EmbeddingRequest + } return schemas.TextCompletionRequest } + + // Cohere embedding: text-only (texts), image-only (images), or mixed (inputs). + // Use model ID to identify embed models, then check for any non-empty payload field. + if strings.Contains(strings.ToLower(modelID), "embed") { + for _, field := range []string{"texts", "images", "inputs"} { + if node, _ := sonic.Get(body, field); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { + return schemas.EmbeddingRequest + } + } + } + } + + // taskType-based image routing + if taskNode, _ := sonic.Get(body, "taskType"); taskNode.Exists() { + taskType, _ := taskNode.String() + switch taskType { + case TaskTypeTextImage: + return schemas.ImageGenerationRequest + case TaskTypeImageVariation: + return schemas.ImageVariationRequest + case TaskTypeInpainting, TaskTypeOutpainting, TaskTypeBackgroundRemoval: + return schemas.ImageEditRequest + } + } + + // URL-decode the model ID once for all model-name checks below + decodedModelID := modelID + if unescaped, err := url.PathUnescape(modelID); err == nil { + decodedModelID = unescaped + } + + // Stability AI: supports both generation (prompt-only) and edit (image+prompt) + if isStabilityAIModel(decodedModelID) { + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + return schemas.ImageGenerationRequest + } + + // explicit image field -> edit request + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + + // Checked after all body-field and model-specific signals so it doesn't shadow known models. + if isPromptOnlyImageGenerationModel(decodedModelID) { + return schemas.ImageGenerationRequest + } + return schemas.TextCompletionRequest } @@ -310,6 +380,382 @@ func (r *BedrockInvokeRequest) ToBifrostTextCompletionRequest(ctx *schemas.Bifro return textReq.ToBifrostTextCompletionRequest(ctx) } +// ToBifrostEmbeddingRequest converts the invoke request to a BifrostEmbeddingRequest. +// Handles both Titan (inputText) and Cohere (texts) embedding formats. +func (r *BedrockInvokeRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostEmbeddingRequest{ + Provider: provider, + Model: model, + } + + if r.InputText != "" { + req.Input = &schemas.EmbeddingInput{Text: &r.InputText} + } else if len(r.Texts) > 0 { + req.Input = &schemas.EmbeddingInput{Texts: r.Texts} + } + // image-only (r.Images) or mixed (r.Inputs): req.Input stays nil; data flows via ExtraParams + + extraParams := make(map[string]interface{}) + // Forward known embedding-only params into ExtraParams so the provider can pick them up + if r.InputType != nil { + extraParams["input_type"] = *r.InputType + } + if r.Normalize != nil { + extraParams["normalize"] = *r.Normalize + } + if len(r.EmbeddingTypes) > 0 { + extraParams["embedding_types"] = r.EmbeddingTypes + } + if r.Truncate != nil { + extraParams["truncate"] = *r.Truncate + } + if len(r.Images) > 0 { + extraParams["images"] = r.Images + } + if len(r.Inputs) > 0 { + extraParams["inputs"] = r.Inputs + } + if r.MaxTokens != nil { + extraParams["max_tokens"] = *r.MaxTokens + } + // Merge any remaining extra params from the request + for k, v := range r.ExtraParams { + extraParams[k] = v + } + + // output_dimension maps to Dimensions; prefer OutputDimension over Dimensions + dimensions := r.Dimensions + if r.OutputDimension != nil { + dimensions = r.OutputDimension + } + params := &schemas.EmbeddingParameters{ + Dimensions: dimensions, + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + req.Params = params + + return req +} + +// ToBifrostImageGenerationRequest converts the invoke request to a BifrostImageGenerationRequest. +// Handles Titan/Nova Canvas (taskType=TEXT_IMAGE with textToImageParams) and Stability AI (flat prompt fields). +func (r *BedrockInvokeRequest) ToBifrostImageGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageGenerationRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageGenerationRequest{ + Provider: provider, + Model: model, + } + + params := &schemas.ImageGenerationParameters{ + NegativePrompt: r.NegativePrompt, + AspectRatio: r.AspectRatio, + N: r.N, + OutputFormat: r.OutputFormat, + Seed: r.Seed, + } + + if r.TextToImageParams != nil { + // Titan / Nova Canvas path + req.Input = &schemas.ImageGenerationInput{Prompt: r.TextToImageParams.Text} + if r.TextToImageParams.NegativeText != nil { + params.NegativePrompt = r.TextToImageParams.NegativeText + } + if r.TextToImageParams.Style != nil { + params.Style = r.TextToImageParams.Style + } + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path β€” prompt comes from the top-level "prompt" field + req.Input = &schemas.ImageGenerationInput{Prompt: r.Prompt} + } + + // Forward any remaining ExtraParams + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + + req.Params = params + return req +} + +// ToBifrostImageEditRequest converts the invoke request to a BifrostImageEditRequest. +// Handles Titan/Nova Canvas (taskType in INPAINTING/OUTPAINTING/BACKGROUND_REMOVAL) and Stability AI (flat image/mask fields). +func (r *BedrockInvokeRequest) ToBifrostImageEditRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageEditRequest, error) { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageEditRequest{ + Provider: provider, + Model: model, + } + params := &schemas.ImageEditParameters{ + NegativePrompt: r.NegativePrompt, + Seed: r.Seed, + } + + if r.TaskType != nil { + // Titan / Nova Canvas path + switch *r.TaskType { + case TaskTypeInpainting: + if r.InPaintingParams == nil { + return nil, fmt.Errorf("inPaintingParams required for INPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.InPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.InPaintingParams.Text, + } + params.Type = schemas.Ptr("inpainting") + if r.InPaintingParams.NegativeText != nil { + params.NegativePrompt = r.InPaintingParams.NegativeText + } + if r.InPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.InPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.InPaintingParams.MaskPrompt != nil || r.InPaintingParams.ReturnMask != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.InPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.InPaintingParams.MaskPrompt + } + if r.InPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.InPaintingParams.ReturnMask + } + } + + case TaskTypeOutpainting: + if r.OutPaintingParams == nil { + return nil, fmt.Errorf("outPaintingParams required for OUTPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.OutPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.OutPaintingParams.Text, + } + params.Type = schemas.Ptr("outpainting") + if r.OutPaintingParams.NegativeText != nil { + params.NegativePrompt = r.OutPaintingParams.NegativeText + } + if r.OutPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.OutPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.OutPaintingParams.MaskPrompt != nil || r.OutPaintingParams.ReturnMask != nil || r.OutPaintingParams.OutPaintingMode != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.OutPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.OutPaintingParams.MaskPrompt + } + if r.OutPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.OutPaintingParams.ReturnMask + } + if r.OutPaintingParams.OutPaintingMode != nil { + params.ExtraParams["outpainting_mode"] = *r.OutPaintingParams.OutPaintingMode + } + } + + case TaskTypeBackgroundRemoval: + if r.BackgroundRemovalParams == nil { + return nil, fmt.Errorf("backgroundRemovalParams required for BACKGROUND_REMOVAL task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.BackgroundRemovalParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode background removal image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + } + params.Type = schemas.Ptr("background_removal") + + default: + return nil, fmt.Errorf("unsupported taskType for image edit: %s", *r.TaskType) + } + + // Map imageGenerationConfig fields into edit params (Titan/Nova Canvas only) + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path + if r.Image == nil { + return nil, fmt.Errorf("image field is required for Stability AI image edit") + } + imgBytes, err := base64.StdEncoding.DecodeString(*r.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode stability AI image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.Prompt, + } + // Infer task type from model name + taskType, err := getStabilityAIEditTaskType(r.ModelID) + if err != nil { + return nil, fmt.Errorf("cannot determine Stability AI edit task: %w", err) + } + params.Type = &taskType + if r.Mask != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.Mask) + if err != nil { + return nil, fmt.Errorf("failed to decode stability AI mask: %w", err) + } + params.Mask = maskBytes + } + } + + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}, len(r.ExtraParams)) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + req.Params = params + return req, nil +} + +// ToBifrostImageVariationRequest converts the invoke request to a BifrostImageVariationRequest. +// Reads from imageVariationParams (Titan/Nova Canvas format). +func (r *BedrockInvokeRequest) ToBifrostImageVariationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageVariationRequest, error) { + if r.ImageVariationParams == nil || len(r.ImageVariationParams.Images) == 0 { + return nil, fmt.Errorf("imageVariationParams.images is required for IMAGE_VARIATION") + } + + primaryBytes, err := base64.StdEncoding.DecodeString(r.ImageVariationParams.Images[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode primary variation image: %w", err) + } + + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageVariationRequest{ + Provider: provider, + Model: model, + Input: &schemas.ImageVariationInput{ + Image: schemas.ImageInput{Image: primaryBytes}, + }, + } + + params := &schemas.ImageVariationParameters{} + extraParams := make(map[string]interface{}) + + // Additional images (index 1+) stored under "images" key for the provider + if len(r.ImageVariationParams.Images) > 1 { + additionalImages := make([][]byte, 0, len(r.ImageVariationParams.Images)-1) + for _, imgB64 := range r.ImageVariationParams.Images[1:] { + imgBytes, err := base64.StdEncoding.DecodeString(imgB64) + if err != nil { + return nil, fmt.Errorf("failed to decode additional variation image: %w", err) + } + additionalImages = append(additionalImages, imgBytes) + } + extraParams["images"] = additionalImages + } + + // Text / negative text / similarity strength go to ExtraParams (provider reads them from there) + if r.ImageVariationParams.Text != nil { + extraParams["prompt"] = *r.ImageVariationParams.Text + } + if r.ImageVariationParams.NegativeText != nil { + extraParams["negativeText"] = *r.ImageVariationParams.NegativeText + } + if r.ImageVariationParams.SimilarityStrength != nil { + extraParams["similarityStrength"] = *r.ImageVariationParams.SimilarityStrength + } + + // ImageGenerationConfig β†’ N, Size, Seed, Quality, CfgScale + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.Seed != nil { + extraParams["seed"] = *cfg.Seed + } + if cfg.Quality != nil { + extraParams["quality"] = *cfg.Quality + } + if cfg.CfgScale != nil { + extraParams["cfgScale"] = *cfg.CfgScale + } + } + + // Forward any remaining ExtraParams from the request body + for k, v := range r.ExtraParams { + extraParams[k] = v + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + + req.Params = params + return req, nil +} + // buildCohereCommandRPrompt converts Cohere Command R's message + chat_history into a text prompt. func (r *BedrockInvokeRequest) buildCohereCommandRPrompt() string { var sb strings.Builder @@ -448,9 +894,16 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. return nil, fmt.Errorf("bifrost response is nil") } - model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + model := "" + if resp.Model != "" { + model = resp.Model + } else { + extraFields := resp.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } } // Nova models: delegate to existing ToBedrockConverseResponse (Nova InvokeModel matches Converse format) @@ -467,6 +920,101 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. return toBedrockInvokeAnthropicResponse(resp, model), nil } +func ToBedrockInvokeImagesResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim (preserves seeds, finish_reasons, etc.) + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + // Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas + if isStabilityAIModel(model) { + return ToStabilityAIImageGenerationResponse(resp) + } + + // Default: Titan Image Generator v1/v2, Nova Canvas β€” reconstruct from Bifrost data + result := &BedrockImageGenerationResponse{} + for _, d := range resp.Data { + result.Images = append(result.Images, d.B64JSON) + } + return result, nil +} + +// ToBedrockEmbeddingInvokeResponse converts a BifrostEmbeddingResponse back to the native +// Bedrock invoke API response format. +// Single-embedding (Titan) responses use: {"embedding": [...], "inputTextTokenCount": N} +// Multi-embedding (Cohere) responses use: {"embeddings": [[...],[...]], "response_type": "embeddings_floats"} +func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost embedding response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + tokenCount := 0 + if resp.Usage != nil { + tokenCount = resp.Usage.PromptTokens + } + + if len(resp.Data) == 0 { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + + // Use model name to distinguish Cohere from Titan β€” not batch size. + // A single-input Cohere request must still return the Cohere envelope format. + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + if strings.Contains(strings.ToLower(model), "cohere") { + floats := make([][]float32, 0, len(resp.Data)) + for _, d := range resp.Data { + float32Emb := make([]float32, len(d.Embedding.EmbeddingArray)) + for i, v := range d.Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + floats = append(floats, float32Emb) + } + return &BedrockInvokeCohereEmbeddingResp{ + Embeddings: floats, + ResponseType: "embeddings_floats", + }, nil + } + + // Titan format + if resp.Data[0].Embedding.EmbeddingArray == nil { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + float32Emb := make([]float32, len(resp.Data[0].Embedding.EmbeddingArray)) + for i, v := range resp.Data[0].Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + return &BedrockInvokeEmbeddingResp{ + Embedding: float32Emb, + InputTextTokenCount: tokenCount, + }, nil +} + // toBedrockInvokeAnthropicResponse converts BifrostResponsesResponse to Anthropic Messages API format. func toBedrockInvokeAnthropicResponse(resp *schemas.BifrostResponsesResponse, model string) *BedrockInvokeMessagesResponse { result := &BedrockInvokeMessagesResponse{ @@ -623,12 +1171,17 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc // final Completed event). Without checking resp.ExtraFields, early chunks would // have model="" and Nova streams would be mis-routed through the Anthropic path. model := "" - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.ExtraFields.ModelRequested != "" { - model = resp.Response.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.Model != "" { - model = resp.Response.Model + if resp.Response != nil { + if resp.Response.Model != "" { + model = resp.Response.Model + } else { + extraFields := resp.Response.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } + } } // Nova models: delegate to existing converse stream response (same format) @@ -656,6 +1209,7 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc bedrockEvent := &BedrockStreamEvent{ InvokeModelRawChunk: rawBytes, } + return "", bedrockEvent, nil } @@ -666,8 +1220,11 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) switch resp.Type { case schemas.ResponsesStreamResponseTypeCreated: - // message_start β€” use ExtraFields.ModelRequested as fallback for early chunks - model := resp.ExtraFields.ModelRequested + // message_start β€” prefer resolved model for accurate family detection on early chunks + model := resp.ExtraFields.ResolvedModelUsed + if model == "" { + model = resp.ExtraFields.OriginalModelRequested + } msgStart := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ @@ -777,7 +1334,7 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) "type": "content_block_delta", "index": idx, "delta": map[string]interface{}{ - "type": "input_json_delta", + "type": "input_json_delta", "partial_json": *resp.Delta, }, } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index e4e96a8017..6d2f9006f2 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -1,7 +1,6 @@ package bedrock import ( - "slices" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -82,147 +81,8 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } -// regionPrefixes is a list of region prefixes used in Bedrock deployments -// Based on AWS region naming patterns and Bedrock deployment configurations -var regionPrefixes = []string{ - "us.", // US regions (us-east-1, us-west-2, etc.) - "eu.", // Europe regions (eu-west-1, eu-central-1, etc.) - "ap.", // Asia Pacific regions (ap-southeast-1, ap-northeast-1, etc.) - "ca.", // Canada regions (ca-central-1, etc.) - "sa.", // South America regions (sa-east-1, etc.) - "af.", // Africa regions (af-south-1, etc.) - "global.", // Global deployment prefix -} - -// extractPrefix extracts the region prefix ending with '.' from a string -// Only recognizes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the prefix (including the dot) if found, empty string otherwise -func extractPrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return prefix - } - } - return "" -} - -// removePrefix removes any region prefix ending with '.' from a string -// Only removes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the string without the prefix -func removePrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return s[len(prefix):] - } - } - return s -} - -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact matches - if slices.Contains(slice, value) { - return value - } - - // Check with region prefix added/removed - valuePrefix := extractPrefix(value) - if valuePrefix != "" { - // value has a prefix, check if slice contains version without prefix - withoutPrefix := removePrefix(value) - if slices.Contains(slice, withoutPrefix) { - return withoutPrefix - } - } - - // Check if any item in slice has a prefix that matches value without prefix - for _, item := range slice { - itemPrefix := extractPrefix(item) - if itemPrefix != "" { - // item has prefix, check if value matches without the prefix - itemWithoutPrefix := removePrefix(item) - if itemWithoutPrefix == value { - return item - } - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize value by removing any region prefix for base model comparison - valueNormalized := removePrefix(value) - - for _, item := range slice { - // Normalize item by removing any region prefix for base model comparison - itemNormalized := removePrefix(item) - - // Check base model match with normalized values (prefix removed from both) - // Return the item from slice (not value) to use the actual name from allowedModels - if schemas.SameBaseModel(itemNormalized, valueNormalized) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// The modelID from the API response should match a deployment value (not the alias/key). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check if any deployment value matches the modelID (with or without prefix) - for aliasKey, deploymentValue := range deployments { - // Exact match - if deploymentValue == modelID || aliasKey == modelID { - return deploymentValue, aliasKey - } - - // Check prefix variations - deploymentPrefix := extractPrefix(deploymentValue) - modelIDPrefix := extractPrefix(modelID) - aliasKeyPrefix := extractPrefix(aliasKey) - - // Case 1: deploymentValue or aliasKey has prefix, modelID doesn't - if (deploymentPrefix != "" && modelIDPrefix == "") || (aliasKeyPrefix != "" && modelIDPrefix == "") { - if removePrefix(deploymentValue) == modelID || removePrefix(aliasKey) == modelID { - return deploymentValue, aliasKey - } - } - - // Case 2: modelID or aliasKey has prefix, deploymentValue doesn't - if (modelIDPrefix != "" && deploymentPrefix == "") || (aliasKeyPrefix != "" && deploymentPrefix == "") { - if removePrefix(modelID) == deploymentValue || removePrefix(modelID) == aliasKey { - return deploymentValue, aliasKey - } - } - - // Case 3: Both have prefixes but different prefixes - if (deploymentPrefix != "" && modelIDPrefix != "" && deploymentPrefix != modelIDPrefix) || (aliasKeyPrefix != "" && modelIDPrefix != "" && aliasKeyPrefix != modelIDPrefix) { - if removePrefix(deploymentValue) == removePrefix(modelID) || removePrefix(aliasKey) == removePrefix(modelID) { - return deploymentValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize both values by removing any region prefix for base model comparison - deploymentNormalized := removePrefix(deploymentValue) - modelIDNormalized := removePrefix(modelID) - - // Check base model match with normalized values (prefix removed from both) - if schemas.SameBaseModel(deploymentNormalized, modelIDNormalized) { - return deploymentValue, aliasKey - } - } - return "", "" -} -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -231,121 +91,41 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } - deploymentValues := make([]string, 0, len(deployments)) - for _, deployment := range deployments { - deploymentValues = append(deploymentValues, deployment) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } - - includedModels := make(map[string]bool) - for _, model := range response.ModelSummaries { - modelID := model.ModelID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - // Check considering global prefix variations - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deploymentValues) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deploymentValues) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ModelID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ModelID, modelID, deploymentAlias, matchedAllowedModel) { - continue - } - - modelEntry := schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.ModelName), - OwnedBy: schemas.Ptr(model.ProviderName), - Architecture: &schemas.Architecture{ - InputModalities: model.InputModalities, - OutputModalities: model.OutputModalities, - }, - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(providerKey) + "/" + deploymentAlias - // Use the actual deployment value (which might have global prefix) - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true - } else { - includedModels[modelID] = true - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[alias] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[alias] = true - } - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { - continue + for _, model := range response.ModelSummaries { + for _, result := range pipeline.FilterModel(model.ModelID) { + modelEntry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.ModelName), + OwnedBy: schemas.Ptr(model.ProviderName), + Architecture: &schemas.Architecture{ + InputModalities: model.InputModalities, + OutputModalities: model.OutputModalities, + }, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/bedrock/rerank_test.go b/core/providers/bedrock/rerank_test.go index 0dff5c3ee2..c1b7bb5480 100644 --- a/core/providers/bedrock/rerank_test.go +++ b/core/providers/bedrock/rerank_test.go @@ -195,27 +195,23 @@ func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) { func TestResolveBedrockDeployment(t *testing.T) { key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", }, } - deployment := resolveBedrockDeployment("cohere-rerank", key) + deployment := key.Aliases.Resolve("cohere-rerank") assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment) - assert.Equal(t, "cohere.rerank-v3-5:0", resolveBedrockDeployment("cohere.rerank-v3-5:0", key)) - assert.Equal(t, "", resolveBedrockDeployment("", key)) + assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0")) + assert.Equal(t, "", key.Aliases.Resolve("")) } func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) { provider := &BedrockProvider{} ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "cohere.rerank-v3-5:0", }, } diff --git a/core/providers/bedrock/s3.go b/core/providers/bedrock/s3.go index da06e5e820..be2d0afb32 100644 --- a/core/providers/bedrock/s3.go +++ b/core/providers/bedrock/s3.go @@ -22,7 +22,6 @@ func uploadToS3( region string, bucket, key string, content []byte, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Create AWS config with credentials var cfg aws.Config @@ -47,7 +46,7 @@ func uploadToS3( } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load AWS config for S3", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err) } // Create S3 client @@ -62,7 +61,7 @@ func uploadToS3( }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to S3: %s/%s", bucket, key), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err) } return nil diff --git a/core/providers/bedrock/signer.go b/core/providers/bedrock/signer.go index 9f12e3bbaf..b7e87ae8d2 100644 --- a/core/providers/bedrock/signer.go +++ b/core/providers/bedrock/signer.go @@ -280,17 +280,16 @@ func signAWSRequestFastHTTP( accessKey, secretKey string, sessionToken *string, region, service string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Get AWS credentials if not provided if accessKey == "" && secretKey == "" { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } accessKey = creds.AccessKeyID secretKey = creds.SecretAccessKey diff --git a/core/providers/bedrock/text.go b/core/providers/bedrock/text.go index 6ad24ee1c8..d31d716ded 100644 --- a/core/providers/bedrock/text.go +++ b/core/providers/bedrock/text.go @@ -127,8 +127,6 @@ func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -154,8 +152,6 @@ func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *s Object: "text_completion", Choices: choices, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -167,11 +163,14 @@ func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionR return nil } - // Determine response format based on model - // Use ModelRequested from ExtraFields if available, otherwise use Model + // Determine response format based on resolved model identity. + // Use ResolvedModelUsed (actual provider ID) for accurate family detection, + // falling back to bifrostResp.Model, then OriginalModelRequested as a last resort. model := bifrostResp.Model - if bifrostResp.ExtraFields.ModelRequested != "" { - model = bifrostResp.ExtraFields.ModelRequested + if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" { + model = bifrostResp.ExtraFields.OriginalModelRequested } if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") { diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 6751527b5b..1e2a447e9d 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -138,7 +138,7 @@ func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - _, _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") + _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") require.NotNil(t, bifrostErr, "expected error when server closes connection") assert.False(t, bifrostErr.IsBifrostError, diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index 388571cc38..6a9e8f79fb 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -651,10 +651,10 @@ type BedrockMetadataEvent struct { // BedrockTitanEmbeddingRequest represents a Bedrock Titan embedding request type BedrockTitanEmbeddingRequest struct { - InputText string `json:"inputText"` // Required: Text to embed + InputText string `json:"inputText"` // Required: Text to embed + Dimensions *int `json:"dimensions,omitempty"` // Optional: 256, 512, or 1024 (titan-embed-text-v2 only) + Normalize *bool `json:"normalize,omitempty"` // Optional: normalize the embedding ExtraParams map[string]interface{} `json:"-"` - // Note: Titan models have fixed dimensions and don't support the dimensions parameter - // ExtraParams can be used for any additional model-specific parameters } // GetExtraParams implements the RequestBodyWithExtraParams interface @@ -668,6 +668,53 @@ type BedrockTitanEmbeddingResponse struct { InputTextTokenCount int `json:"inputTextTokenCount"` // Number of tokens in input } +// BedrockCohereEmbeddingContentBlock represents a single content block in a mixed input +type BedrockCohereEmbeddingContentBlock struct { + Type string `json:"type"` // "text" or "image_url" + Text *string `json:"text,omitempty"` // for type=text + ImageURL *BedrockCohereEmbeddingImageURL `json:"image_url,omitempty"` // for type=image_url +} + +// BedrockCohereEmbeddingImageURL holds the URL for an image content block +type BedrockCohereEmbeddingImageURL struct { + URL string `json:"url"` +} + +// BedrockCohereEmbeddingInput represents a mixed text+image input +type BedrockCohereEmbeddingInput struct { + Content []BedrockCohereEmbeddingContentBlock `json:"content"` +} + +// BedrockCohereEmbeddingRequest represents a Bedrock Cohere embedding request. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the body. +type BedrockCohereEmbeddingRequest struct { + InputType string `json:"input_type"` // Required + Texts []string `json:"texts,omitempty"` // text-only inputs + Images []string `json:"images,omitempty"` // image-only inputs (data URIs) + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // mixed text+image inputs + EmbeddingTypes []string `json:"embedding_types,omitempty"` // e.g. ["float"] + OutputDimension *int `json:"output_dimension,omitempty"` // 256, 512, 1024, or 1536 + MaxTokens *int `json:"max_tokens,omitempty"` // max 128000 + Truncate *string `json:"truncate,omitempty"` // NONE, LEFT, or RIGHT + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *BedrockCohereEmbeddingRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockCohereEmbeddingResponse handles both Bedrock Cohere embedding response shapes. +// When embedding_types is not set, Bedrock returns embeddings as a raw [][]float32 +// ("embeddings_floats"). When embedding_types is set, it returns an object with typed +// arrays ("embeddings_by_type"). Using json.RawMessage defers parsing until we know the shape. +type BedrockCohereEmbeddingResponse struct { + ID string `json:"id"` + Embeddings json.RawMessage `json:"embeddings"` + ResponseType string `json:"response_type"` + Texts []string `json:"texts,omitempty"` +} + const TaskTypeTextImage = "TEXT_IMAGE" const TaskTypeImageVariation = "IMAGE_VARIATION" const TaskTypeInpainting = "INPAINTING" @@ -760,11 +807,79 @@ type BedrockBackgroundRemovalParams struct { Image string `json:"image"` // Base64-encoded image } -// BedrockImageGenerationResponse represents a Bedrock image generation response +// StabilityAIImageGenerationRequest represents the request format for Stability AI models on Bedrock +// (e.g. stability.stable-image-core-v1:1, stability.stable-image-ultra-v1:1) +type StabilityAIImageGenerationRequest struct { + Prompt string `json:"prompt"` + AspectRatio *string `json:"aspect_ratio,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + Seed *int `json:"seed,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *StabilityAIImageGenerationRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// StabilityAIImageEditRequest is the flat JSON body for Stability AI image-edit models on Bedrock. +// Only the fields valid for the detected task type are populated. +type StabilityAIImageEditRequest struct { + // Shared params + Image *string `json:"image,omitempty"` // base64, primary input image + Prompt *string `json:"prompt,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + Seed *int `json:"seed,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + StylePreset *string `json:"style_preset,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask image + GrowMask *int `json:"grow_mask,omitempty"` + + // Outpaint + Left *int `json:"left,omitempty"` + Right *int `json:"right,omitempty"` + Up *int `json:"up,omitempty"` + Down *int `json:"down,omitempty"` + + // Upscale-creative / upscale-conservative / outpaint + Creativity *float64 `json:"creativity,omitempty"` + + // Recolor + SelectPrompt *string `json:"select_prompt,omitempty"` + + // Search-replace + SearchPrompt *string `json:"search_prompt,omitempty"` + + // Control-sketch / control-structure + ControlStrength *float64 `json:"control_strength,omitempty"` + + // Style-guide + AspectRatio *string `json:"aspect_ratio,omitempty"` + Fidelity *float64 `json:"fidelity,omitempty"` + + // Style-transfer (uses different image field names) + InitImage *string `json:"init_image,omitempty"` + StyleImage *string `json:"style_image,omitempty"` + StyleStrength *float64 `json:"style_strength,omitempty"` + CompositionFidelity *float64 `json:"composition_fidelity,omitempty"` + ChangeStrength *float64 `json:"change_strength,omitempty"` + + ExtraParams map[string]interface{} `json:"-"` +} + +func (req *StabilityAIImageEditRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockImageGenerationResponse represents a Bedrock image generation response. +// The Seeds and FinishReasons fields are populated by Stability AI edit models only. type BedrockImageGenerationResponse struct { - Images []string `json:"images"` // list of Base64 encoded images - MaskImage string `json:"maskImage"` // Base64 encoded mask image (optional) - Error string `json:"error"` // error message (if present) + Images []string `json:"images"` // list of Base64 encoded images + MaskImage string `json:"maskImage,omitempty"` // Base64 encoded mask image (optional) + Error string `json:"error,omitempty"` // error message (if present) + Seeds []int `json:"seeds,omitempty"` // Stability AI: seeds used per image + FinishReasons []*string `json:"finish_reasons,omitempty"` // Stability AI: finish reason per image (may be null) } // ==================== MODELS TYPES ==================== @@ -957,6 +1072,37 @@ type BedrockInvokeRequest struct { FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` PresencePenalty *float64 `json:"presence_penalty,omitempty"` + // ==================== BEDROCK IMAGE GEN / EDIT / VARIATION (Titan/Nova Canvas) ==================== + + TaskType *string `json:"taskType,omitempty"` + TextToImageParams *BedrockTextToImageParams `json:"textToImageParams,omitempty"` + ImageVariationParams *BedrockImageVariationParams `json:"imageVariationParams,omitempty"` + InPaintingParams *BedrockInPaintingParams `json:"inPaintingParams,omitempty"` + OutPaintingParams *BedrockOutPaintingParams `json:"outPaintingParams,omitempty"` + BackgroundRemovalParams *BedrockBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"` + ImageGenerationConfig *ImageGenerationConfig `json:"imageGenerationConfig,omitempty"` + + // ==================== STABILITY AI IMAGE ==================== + + // Image is the base64-encoded input image (SA edit / variation) + Image *string `json:"image,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask for inpainting + NegativePrompt *string `json:"negative_prompt,omitempty"` // SA gen / edit + AspectRatio *string `json:"aspect_ratio,omitempty"` // SA gen + OutputFormat *string `json:"output_format,omitempty"` // SA gen + Seed *int `json:"seed,omitempty"` // SA gen / edit + + // ==================== EMBEDDINGS ==================== + + InputText string `json:"inputText,omitempty"` // Titan embed + Texts []string `json:"texts,omitempty"` // Cohere embed + InputType *string `json:"input_type,omitempty"` // Cohere embed + Normalize *bool `json:"normalize,omitempty"` // Titan embed v2 + Dimensions *int `json:"dimensions,omitempty"` // Titan embed v2 + EmbeddingTypes []string `json:"embedding_types,omitempty"` // Cohere embed: ["float","int8","uint8","binary","ubinary"] + OutputDimension *int `json:"output_dimension,omitempty"` // Cohere embed: 256, 512, 1024, 1536 + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // Cohere embed: mixed text+image inputs + // ==================== INTERNAL ==================== Stream bool `json:"-"` ExtraParams map[string]interface{} `json:"-"` @@ -967,3 +1113,15 @@ type BedrockCohereRMessage struct { Role string `json:"role"` // "USER" or "CHATBOT" Message string `json:"message"` // Message content } + +// BedrockInvokeEmbeddingResp is the Titan single-embedding invoke response format. +type BedrockInvokeEmbeddingResp struct { + Embedding []float32 `json:"embedding"` + InputTextTokenCount int `json:"inputTextTokenCount"` +} + +// BedrockInvokeCohereEmbeddingResp is the Cohere multi-embedding invoke response format. +type BedrockInvokeCohereEmbeddingResp struct { + Embeddings [][]float32 `json:"embeddings"` + ResponseType string `json:"response_type"` +} diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go index e880087e8b..3a000a76af 100644 --- a/core/providers/cerebras/cerebras.go +++ b/core/providers/cerebras/cerebras.go @@ -178,9 +178,6 @@ func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/cohere/chat.go b/core/providers/cohere/chat.go index 1623d23737..fc19d18919 100644 --- a/core/providers/cohere/chat.go +++ b/core/providers/cohere/chat.go @@ -372,8 +372,6 @@ func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas }, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Cohere, }, } diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go index 2986760329..b67ef8668e 100644 --- a/core/providers/cohere/cohere.go +++ b/core/providers/cohere/cohere.go @@ -155,7 +155,7 @@ func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, def // completeRequest sends a request to Cohere's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { +func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -199,10 +199,10 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, latency, providerResponseHeaders, parseCohereError(resp, meta) + return nil, latency, providerResponseHeaders, parseCohereError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, latency, providerResponseHeaders, decodeErr } @@ -217,8 +217,6 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -234,7 +232,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Parse and add query parameters u, err := url.Parse(baseURL) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse request URL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse request url", err) } q := u.Query() @@ -269,15 +267,12 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseCohereError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse Cohere list models response @@ -288,7 +283,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := cohereResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -352,17 +347,12 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -375,9 +365,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -396,9 +383,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -424,7 +408,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, err } - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -435,8 +418,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } reqBody.Stream = schemas.Ptr(true) return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -486,9 +468,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -497,11 +479,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -520,9 +498,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -560,7 +538,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -582,11 +560,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -594,11 +567,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext if response != nil { response.ID = responseID response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -638,18 +608,13 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereResponsesRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Convert to Cohere v2 request - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -662,9 +627,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -685,9 +647,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -711,7 +670,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, err } - providerName := provider.GetProviderKey() // Convert to Cohere v2 request and add streaming jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -725,8 +683,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos reqBody.Stream = schemas.Ptr(true) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -774,9 +731,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -785,11 +742,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -808,9 +761,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -852,8 +805,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + provider.logger.Warn("Error reading stream: %v", readErr) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -873,11 +826,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -886,11 +834,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() chunkIndex++ @@ -934,18 +879,13 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Create Bifrost request for conversion - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -958,9 +898,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -979,9 +916,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := response.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1010,17 +944,12 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereRerankRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1033,9 +962,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1056,9 +982,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1211,16 +1134,12 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereCountTokensRequest(request) - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1230,11 +1149,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonBody, provider.buildRequestURL(ctx, "/v1/tokenize", schemas.CountTokensRequest), key.Value.GetValue(), - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -1248,9 +1162,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return &schemas.BifrostCountTokensResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1272,12 +1183,9 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch bifrostResponse := cohereResponse.ToBifrostCountTokensResponse(request.Model) if bifrostResponse == nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil Cohere count tokens response"), providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil cohere count tokens response")), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.CountTokensRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders diff --git a/core/providers/cohere/cohere_test.go b/core/providers/cohere/cohere_test.go index f2fca1028f..23c73911bc 100644 --- a/core/providers/cohere/cohere_test.go +++ b/core/providers/cohere/cohere_test.go @@ -32,27 +32,27 @@ func TestCohere(t *testing.T) { RerankModel: "rerank-v3.5", ReasoningModel: "command-a-reasoning-08-2025", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not typical for Cohere - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, // May not support automatic - ImageURL: false, // Supported by c4ai-aya-vision-8b model - ImageBase64: true, // Supported by c4ai-aya-vision-8b model - MultipleImages: false, // Supported by c4ai-aya-vision-8b model - FileBase64: false, // Not supported - FileURL: false, // Not supported - CompleteEnd2End: false, - Embedding: true, - Rerank: true, - Reasoning: true, - ListModels: true, - CountTokens: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, // May not support automatic + ImageURL: false, // Supported by c4ai-aya-vision-8b model + ImageBase64: true, // Supported by c4ai-aya-vision-8b model + MultipleImages: false, // Supported by c4ai-aya-vision-8b model + FileBase64: false, // Not supported + FileURL: false, // Not supported + CompleteEnd2End: false, + Embedding: true, + Rerank: true, + Reasoning: true, + ListModels: true, + CountTokens: true, }, } diff --git a/core/providers/cohere/errors.go b/core/providers/cohere/errors.go index e9183b1b34..e444d86650 100644 --- a/core/providers/cohere/errors.go +++ b/core/providers/cohere/errors.go @@ -6,7 +6,7 @@ import ( "github.com/valyala/fasthttp" ) -func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp CohereError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) bifrostErr.Type = &errorResp.Type @@ -17,10 +17,5 @@ func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetada if errorResp.Code != nil { bifrostErr.Error.Code = errorResp.Code } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index 3df2aab89a..3b285f97b6 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -2,8 +2,9 @@ package cohere import ( "encoding/json" - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -44,7 +45,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage `json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -53,37 +54,39 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - includedModels := make(map[string]bool) - for _, model := range response.Models { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, model.Name) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.Name, - Name: schemas.Ptr(model.Name), - ContextLength: schemas.Ptr(int(model.ContextLength)), - SupportedMethods: model.Endpoints, - }) - includedModels[model.Name] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Models { + // Cohere uses model.Name as the model identifier + for _, result := range pipeline.FilterModel(model.Name) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + ContextLength: schemas.Ptr(int(model.ContextLength)), + SupportedMethods: model.Endpoints, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index afa095e56d..40ddbeb4ad 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "io" "mime/multipart" "net/http" @@ -74,8 +73,6 @@ func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -103,10 +100,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, // Extract and set provider response headers so they're available on error paths ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseElevenlabsError(resp) } var elevenlabsResponse ElevenlabsListModelsResponse @@ -115,7 +109,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, return nil, bifrostErr } - response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := elevenlabsResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -188,8 +182,6 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -211,7 +203,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche endpoint = "/v1/text-to-speech/" + voice } } else { - return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil) } requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request) @@ -228,8 +220,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToElevenlabsSpeechRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -250,26 +241,18 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Get the response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Create response based on whether timestamps were requested bifrostResponse := &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -282,7 +265,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche if withTimestampsRequest { var timestampResponse ElevenlabsSpeechWithTimestampsResponse if err := sonic.Unmarshal(body, ×tampResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err) } bifrostResponse.AudioBase64 = ×tampResponse.AudioBase64 @@ -321,15 +304,12 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po return nil, err } - providerName := provider.GetProviderKey() - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToElevenlabsSpeechRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -345,7 +325,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil) if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request)) @@ -376,9 +356,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -387,11 +367,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Create response channel @@ -402,9 +378,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -445,7 +421,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", err) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -458,11 +434,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po Type: schemas.SpeechStreamResponseTypeDelta, Audio: audioChunk, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -481,11 +454,8 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po Type: schemas.SpeechStreamResponseTypeDone, Audio: []byte{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -506,32 +476,30 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k return nil, err } - providerName := provider.GetProviderKey() - reqBody := ToElevenlabsTranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil) } hasFile := len(reqBody.File) > 0 hasURL := reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" if hasFile && hasURL { - return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil) } if !hasFile && !hasURL { - return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil) } var body bytes.Buffer writer := multipart.NewWriter(&body) - if bifrostErr := writeTranscriptionMultipart(writer, reqBody, providerName); bifrostErr != nil { + if bifrostErr := writeTranscriptionMultipart(writer, reqBody); bifrostErr != nil { return nil, bifrostErr } contentType := writer.FormDataContentType() if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err) } req := fasthttp.AcquireRequest() @@ -562,17 +530,12 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k // Extract and set provider response headers so they're available on error paths ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.TranscriptionRequest, - }) + return nil, parseElevenlabsError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -588,18 +551,15 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k chunks, err := parseTranscriptionResponse(responseBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil) } if len(chunks) == 0 { - return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil) } response := ToBifrostTranscriptionResponse(chunks) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } @@ -607,7 +567,7 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponse interface{} if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) + rawResponse = string(responseBody) } response.ExtraFields.RawResponse = rawResponse } @@ -615,9 +575,9 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k return response, nil } -func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError { +func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError { if err := writer.WriteField("model_id", reqBody.ModelID); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model_id field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model_id field", err) } if len(reqBody.File) > 0 { @@ -627,98 +587,98 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create file field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create file field", err) } if _, err := fileWriter.Write(reqBody.File); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file data", err) } } if reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" { if err := writer.WriteField("cloud_storage_url", *reqBody.CloudStorageURL); err != nil { - return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err) } } if reqBody.LanguageCode != nil && strings.TrimSpace(*reqBody.LanguageCode) != "" { if err := writer.WriteField("language_code", *reqBody.LanguageCode); err != nil { - return providerUtils.NewBifrostOperationError("failed to write language_code field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write language_code field", err) } } if reqBody.TagAudioEvents != nil { if err := writer.WriteField("tag_audio_events", strconv.FormatBool(*reqBody.TagAudioEvents)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err) } } if reqBody.NumSpeakers != nil && *reqBody.NumSpeakers > 0 { if err := writer.WriteField("num_speakers", strconv.Itoa(*reqBody.NumSpeakers)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err) } } if reqBody.TimestampsGranularity != nil && *reqBody.TimestampsGranularity != "" { if err := writer.WriteField("timestamps_granularity", string(*reqBody.TimestampsGranularity)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err) } } if reqBody.Diarize != nil { if err := writer.WriteField("diarize", strconv.FormatBool(*reqBody.Diarize)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write diarize field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write diarize field", err) } } if reqBody.DiarizationThreshold != nil { if err := writer.WriteField("diarization_threshold", strconv.FormatFloat(*reqBody.DiarizationThreshold, 'f', -1, 64)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err) } } if len(reqBody.AdditionalFormats) > 0 { payload, err := providerUtils.MarshalSorted(reqBody.AdditionalFormats) if err != nil { - return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err, providerName) + return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err) } if err := writer.WriteField("additional_formats", string(payload)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err) } } if reqBody.FileFormat != nil && *reqBody.FileFormat != "" { if err := writer.WriteField("file_format", string(*reqBody.FileFormat)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file_format field", err) } } if reqBody.Webhook != nil { if err := writer.WriteField("webhook", strconv.FormatBool(*reqBody.Webhook)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook field", err) } } if reqBody.WebhookID != nil && strings.TrimSpace(*reqBody.WebhookID) != "" { if err := writer.WriteField("webhook_id", *reqBody.WebhookID); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err) } } if reqBody.Temperature != nil { if err := writer.WriteField("temperature", strconv.FormatFloat(*reqBody.Temperature, 'f', -1, 64)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write temperature field", err) } } if reqBody.Seed != nil { if err := writer.WriteField("seed", strconv.Itoa(*reqBody.Seed)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seed field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seed field", err) } } if reqBody.UseMultiChannel != nil { if err := writer.WriteField("use_multi_channel", strconv.FormatBool(*reqBody.UseMultiChannel)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err) } } @@ -727,16 +687,16 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr case string: if strings.TrimSpace(v) != "" { if err := writer.WriteField("webhook_metadata", v); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err) } } default: payload, err := providerUtils.MarshalSorted(v) if err != nil { - return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err, providerName) + return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err) } if err := writer.WriteField("webhook_metadata", string(payload)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err) } } } diff --git a/core/providers/elevenlabs/errors.go b/core/providers/elevenlabs/errors.go index 374e251958..f30807efd5 100644 --- a/core/providers/elevenlabs/errors.go +++ b/core/providers/elevenlabs/errors.go @@ -9,7 +9,7 @@ import ( schemas "github.com/maximhq/bifrost/core/schemas" ) -func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseElevenlabsError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp ElevenlabsError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Detail != nil { @@ -64,11 +64,6 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe Message: message, }, } - if meta != nil { - result.ExtraFields.Provider = meta.Provider - result.ExtraFields.ModelRequested = meta.Model - result.ExtraFields.RequestType = meta.RequestType - } return result } } @@ -91,10 +86,5 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe bifrostErr.Error.Message = message } } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go index c211e85196..f762d97ee8 100644 --- a/core/providers/elevenlabs/models.go +++ b/core/providers/elevenlabs/models.go @@ -1,12 +1,13 @@ package elevenlabs import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,35 +16,36 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid Data: make([]schemas.Model, 0, len(*response)), } - includedModels := make(map[string]bool) - for _, model := range *response { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, model.ModelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.ModelID, - Name: schemas.Ptr(model.Name), - }) - includedModels[model.ModelID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range *response { + for _, result := range pipeline.FilterModel(model.ModelID) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/elevenlabs/realtime.go b/core/providers/elevenlabs/realtime.go index f124f58339..a18e1cd514 100644 --- a/core/providers/elevenlabs/realtime.go +++ b/core/providers/elevenlabs/realtime.go @@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string] return headers } +// SupportsRealtimeWebRTC returns false β€” ElevenLabs WebRTC SDP exchange is not yet implemented. +func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool { + return false +} + +// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs. +func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) { + return "", &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(400), + Error: &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"}, + } +} + +func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + return false +} + +func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string { + return "" +} + +func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string { + return "" +} + +func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + return eventType == schemas.RTEventResponseDone +} + // ElevenLabs Conversational AI WebSocket event types const ( elConversationInitMetadata = "conversation_initiation_metadata" @@ -50,8 +88,8 @@ const ( elInterruption = "interruption" elClientToolCall = "client_tool_call" - elUserAudioChunk = "user_audio_chunk" - elPong = "pong" + elUserAudioChunk = "user_audio_chunk" + elPong = "pong" elClientToolResult = "client_tool_result" elContextualUpdate = "contextual_update" ) @@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra } case elAgentResponse: - event.Type = schemas.RTEventResponseTextDone + event.Type = schemas.RTEventResponseDone if raw.AgentResponse != nil { var agentResp elevenlabsTranscriptEvent if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil { @@ -194,10 +232,6 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra // ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON. func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - switch bifrostEvent.Type { case schemas.RTEventInputAudioAppend: if bifrostEvent.Delta == nil { diff --git a/core/providers/gemini/batch.go b/core/providers/gemini/batch.go index e3d92383f6..8f0405e524 100644 --- a/core/providers/gemini/batch.go +++ b/core/providers/gemini/batch.go @@ -249,8 +249,6 @@ func extractBatchIDFromName(name string) string { // downloadBatchResultsFile downloads and parses a batch results file from Gemini. // Returns the parsed result items from the JSONL file and any parse errors encountered. func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request to download the file req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -287,15 +285,12 @@ func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchResultsRequest, - }) + return nil, nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse JSONL content - each line is a separate JSON object diff --git a/core/providers/gemini/errors.go b/core/providers/gemini/errors.go index adf217a141..2d60a7bcd3 100644 --- a/core/providers/gemini/errors.go +++ b/core/providers/gemini/errors.go @@ -36,7 +36,7 @@ func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError { } // parseGeminiError parses Gemini error responses -func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError { // Try to parse as []GeminiGenerationError var errorResps []GeminiGenerationError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps) @@ -62,11 +62,6 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada } // Set Message to trimmed concatenated message bifrostErr.Error.Message = message - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } @@ -80,10 +75,5 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code)) bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index d884294534..32d3c96f93 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -97,9 +97,7 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { // completeRequest handles the common HTTP request pattern for Gemini API calls. // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx), // returns (nil, nil, latency, nil) β€” callers must check the context flag. -func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - +func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) { // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -146,10 +144,10 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp, meta) + return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, nil, latency, providerResponseHeaders, decodeErr } @@ -161,13 +159,13 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod // Parse Gemini's response var geminiResponse GenerateContentResponse if err := sonic.Unmarshal(body, &geminiResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var rawResponse interface{} if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { if err := sonic.Unmarshal(body, &rawResponse); err != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } } @@ -208,10 +206,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ListModelsRequest, - }) + return nil, parseGeminiError(resp) } // Parse Gemini's response @@ -227,7 +222,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key } } - response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -282,24 +277,17 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - jsonData, err := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -312,9 +300,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -323,9 +308,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := geminiResponse.ToBifrostChatResponse() - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -363,8 +345,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion request is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -450,9 +431,9 @@ func HandleGeminiChatCompletionStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -462,11 +443,7 @@ func HandleGeminiChatCompletionStream( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) respBody := append([]byte(nil), resp.Body()...) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -483,9 +460,9 @@ func HandleGeminiChatCompletionStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -495,7 +472,6 @@ func HandleGeminiChatCompletionStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -557,7 +533,7 @@ func HandleGeminiChatCompletionStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } // Process chunk using shared function @@ -572,11 +548,6 @@ func HandleGeminiChatCompletionStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -597,11 +568,6 @@ func HandleGeminiChatCompletionStream( // Convert to Bifrost stream response response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -613,11 +579,8 @@ func HandleGeminiChatCompletionStream( response.Model = modelName } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -692,8 +655,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -705,11 +667,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem } // Use struct directly for JSON marshaling - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -722,9 +680,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -735,9 +690,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -767,13 +719,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( bodyReader io.Reader, // Optional: for large payload request streaming (pass nil for normal path) bodySize int, // Required if bodyReader is non-nil ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - meta := &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - } - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -807,14 +752,14 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Handle error response β€” materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - bifrostErr := parseGeminiError(resp, meta) + bifrostErr := parseGeminiError(resp) wait() fasthttp.ReleaseResponse(resp) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Delegate large response detection + normal buffered path to shared utility - responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -830,9 +775,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( Model: request.Model, Usage: usage, } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // resp owned by reader in context β€” don't release wait() @@ -844,12 +786,9 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Normal parse-and-convert path var geminiResponse GenerateContentResponse if unmarshalErr := sonic.Unmarshal(responseBody, &geminiResponse); unmarshalErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr) } bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -901,8 +840,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -987,9 +925,9 @@ func HandleGeminiResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -998,11 +936,7 @@ func HandleGeminiResponsesStream( // Check for HTTP errors β€” use parseGeminiError to preserve upstream error details if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1019,9 +953,9 @@ func HandleGeminiResponsesStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1032,7 +966,6 @@ func HandleGeminiResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError( @@ -1101,7 +1034,7 @@ func HandleGeminiResponsesStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } @@ -1117,11 +1050,6 @@ func HandleGeminiResponsesStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1139,11 +1067,6 @@ func HandleGeminiResponsesStream( // Convert to Bifrost responses stream response responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1152,11 +1075,8 @@ func HandleGeminiResponsesStream( for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1209,11 +1129,8 @@ func HandleGeminiResponsesStream( continue } finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1258,8 +1175,7 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiEmbeddingRequest(request), nil - }, - providerName) + }) if err != nil { return nil, err } @@ -1310,17 +1226,13 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) wait() fasthttp.ReleaseResponse(resp) return nil, parsedErr } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -1333,9 +1245,6 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1357,12 +1266,9 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model) if bifrostResponse == nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, - fmt.Errorf("failed to convert Gemini embedding response to Bifrost format"), providerName) + fmt.Errorf("failed to convert Gemini embedding response to Bifrost format")) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1391,18 +1297,13 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.SpeechRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1414,9 +1315,6 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1428,13 +1326,10 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. } response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.SpeechRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1461,16 +1356,13 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using speech-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1516,9 +1408,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1527,11 +1419,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1550,9 +1438,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1592,7 +1480,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1612,11 +1500,6 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) @@ -1667,11 +1550,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDelta, Audio: audioChunk, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -1688,11 +1568,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDone, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } response.BackfillParams(request) @@ -1720,18 +1597,13 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiTranscriptionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TranscriptionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1743,9 +1615,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostTranscriptionResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TranscriptionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1755,9 +1624,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s response := geminiResponse.ToBifrostTranscriptionResponse() // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TranscriptionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1779,16 +1645,13 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using transcription-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiTranscriptionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1834,9 +1697,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1845,11 +1708,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1868,9 +1727,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1910,7 +1769,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1929,11 +1788,6 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) @@ -1978,11 +1832,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, Type: schemas.TranscriptionStreamResponseTypeDelta, Delta: &deltaText, // Delta text for this chunk ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -2005,11 +1856,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, TotalTokens: usage.TotalTokens, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2042,18 +1890,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiImageGenerationRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2065,9 +1908,6 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2076,25 +1916,16 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, - } return nil, bifrostErr } if response == nil { return nil, providerUtils.NewBifrostOperationError( "failed to convert Gemini image generation response", fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"), - provider.GetProviderKey(), ) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2111,16 +1942,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key // handleImagenImageGeneration handles Imagen model requests using Vertex AI endpoint with API key auth func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Prepare Imagen request body jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToImagenImageGenerationRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2160,16 +1988,11 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse Imagen response - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, decodeErr } @@ -2177,10 +2000,7 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2192,9 +2012,6 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost } // Convert to Bifrost format response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2219,8 +2036,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, err } - providerName := provider.GetProviderKey() - // Handle Imagen models using :predict endpoint if schemas.IsImagenModel(request.Model) { jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2228,8 +2043,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToImagenImageEditRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2264,15 +2078,10 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, decodeErr } @@ -2280,10 +2089,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem imagenRespOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2295,9 +2101,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2316,18 +2119,13 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiImageEditRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2339,9 +2137,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2350,25 +2145,16 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, - } return nil, bifrostErr } if response == nil { return nil, providerUtils.NewBifrostOperationError( "failed to convert Gemini image edit response", fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"), - providerName, ) } // Set ExtraFields - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2400,7 +2186,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() model := bifrostReq.Model jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2409,7 +2194,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2442,17 +2226,13 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // use handle provider response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response @@ -2468,12 +2248,9 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = model - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { bifrostResp.ExtraFields.RawRequest = rawRequest @@ -2491,10 +2268,9 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s return nil, err } - providerName := provider.GetProviderKey() operationID := bifrostReq.ID - operationID = providerUtils.StripVideoIDProviderSuffix(operationID, providerName) + operationID = providerUtils.StripVideoIDProviderSuffix(operationID, provider.GetProviderKey()) // Create HTTP request req := fasthttp.AcquireRequest() @@ -2519,10 +2295,8 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + respBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response @@ -2536,12 +2310,10 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) // Add extra fields bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { bifrostResp.ExtraFields.RawResponse = rawResponse @@ -2555,9 +2327,8 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first so download behavior follows retrieve status. bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2572,11 +2343,10 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), nil, - providerName, ) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte contentType := "video/mp4" @@ -2587,7 +2357,7 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded latency = time.Since(startTime) @@ -2618,17 +2388,16 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), nil, - providerName, ) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) } else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ VideoID: request.ID, @@ -2637,8 +2406,6 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest return bifrostResp, nil } @@ -2668,18 +2435,16 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - // Validate that either InputFileID or Requests is provided, but not both hasFileInput := request.InputFileID != "" hasInlineRequests := len(request.Requests) > 0 if !hasFileInput && !hasInlineRequests { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil) } if hasFileInput && hasInlineRequests { - return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil) } // Build the batch request with proper nested structure @@ -2712,12 +2477,12 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch if rawMessages, ok := body["messages"]; ok { messagesBytes, err := providerUtils.MarshalSorted(rawMessages) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err) } var chatMessages []schemas.ChatMessage err = sonic.Unmarshal(messagesBytes, &chatMessages) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err) } contents, systemInstruction := convertBifrostMessagesToGemini(chatMessages) @@ -2727,11 +2492,11 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // If no "messages" key, try direct unmarshal (already in Gemini format) requestBytes, err := providerUtils.MarshalSorted(body) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err) } err = sonic.Unmarshal(requestBytes, &geminiReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err) } } @@ -2755,7 +2520,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch jsonData, err := providerUtils.MarshalSorted(batchReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Create HTTP request @@ -2793,31 +2558,27 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.BatchCreateRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse the batch job response var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { provider.logger.Error("gemini batch create unmarshal error: " + err.Error()) - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Check for metadata if geminiResp.Metadata == nil { - return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil) } // Check for batch stats if geminiResp.Metadata.BatchStats == nil { - return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil) } // Calculate request counts based on response totalRequests := geminiResp.Metadata.BatchStats.RequestCount @@ -2860,9 +2621,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch Failed: failedCount, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2881,8 +2640,6 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // batchListByKey lists batch jobs for Gemini for a single key. func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -2930,26 +2687,21 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, latency, nil } - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format @@ -2961,10 +2713,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key Status: ToBifrostBatchStatus(batch.Metadata.State), CreatedAt: parseGeminiTimestamp(batch.Metadata.CreateTime), OperationName: &batch.Name, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, + ExtraFields: schemas.BifrostResponseExtraFields{}, }) } @@ -2980,9 +2729,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key HasMore: hasMore, NextCursor: nextCursor, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, latency, nil } @@ -2996,16 +2743,14 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil) } // Initialize serial pagination helper (Gemini uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3016,10 +2761,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -3051,9 +2792,7 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3065,8 +2804,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc // batchRetrieveByKey retrieves a specific batch job for Gemini for a single key. func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3099,20 +2836,17 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var completedCount, failedCount int @@ -3141,9 +2875,7 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, Failed: failedCount, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3154,14 +2886,12 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil) } // Try each key until we find the batch @@ -3180,8 +2910,6 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys // batchCancelByKey cancels a batch job for Gemini for a single key. func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create HTTP request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3219,15 +2947,9 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke if resp.StatusCode() == fasthttp.StatusNotFound || resp.StatusCode() == fasthttp.StatusMethodNotAllowed { // 404 could mean batch not found or cancel not supported // Return the error instead of assuming completed - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchCancelRequest, - }) + return nil, parseGeminiError(resp) } now := time.Now().Unix() @@ -3237,9 +2959,7 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke Status: schemas.BatchStatusCancelling, CancellingAt: &now, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3251,14 +2971,12 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil) } // Try each key until cancellation succeeds @@ -3280,8 +2998,6 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] // batches.delete indicates the client is no longer interested in the operation result. // It does not cancel the operation. If the server doesn't support this method, it returns UNIMPLEMENTED. func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) @@ -3310,10 +3026,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke } if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostBatchDeleteResponse{ @@ -3321,9 +3034,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke Object: "batch", Status: schemas.BatchStatusDeleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3336,14 +3047,12 @@ func (provider *GeminiProvider) BatchDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil) } var lastError *schemas.BifrostError @@ -3491,8 +3200,6 @@ func readNextSSEDataLine(reader *bufio.Reader, skipInlineData bool) ([]byte, err // batchResultsByKey retrieves batch results for Gemini for a single key. func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // We need to get the full batch response with results, so make the API call directly req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3526,20 +3233,17 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.BatchResultsRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiBatchJobResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Check if batch is still processing @@ -3547,7 +3251,6 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("batch %s is still processing (state: %s), results not yet available", request.BatchID, geminiResp.Metadata.State), nil, - providerName, ) } @@ -3635,9 +3338,7 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3656,14 +3357,12 @@ func (provider *GeminiProvider) BatchResults(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil) } // Try each key until we get results @@ -3687,10 +3386,8 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart request @@ -3700,14 +3397,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add file metadata as JSON metadataField, err := writer.CreateFormField("metadata") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err) } metadataJSON, err := providerUtils.SetJSONField([]byte(`{}`), "file.displayName", request.Filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } if _, err := metadataField.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err) } // Add file content @@ -3717,14 +3414,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3755,15 +3452,12 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileUploadRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse response - wrapped in "file" object @@ -3771,7 +3465,7 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche File GeminiFileResponse `json:"file"` } if err := sonic.Unmarshal(body, &responseWrapper); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } geminiResp := responseWrapper.File @@ -3807,17 +3501,13 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } // fileListByKey lists files from Gemini for a single key. func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3854,20 +3544,17 @@ func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost response @@ -3876,9 +3563,7 @@ func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key s Data: make([]schemas.FileObject, len(geminiResp.Files)), HasMore: geminiResp.NextPageToken != "", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3931,16 +3616,14 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil) } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3951,10 +3634,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3986,9 +3665,7 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -4000,8 +3677,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch // fileRetrieveByKey retrieves file metadata from Gemini for a single key. func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4032,20 +3707,17 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var sizeBytes int64 @@ -4082,9 +3754,7 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4095,14 +3765,12 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil) } // Try each key until we find the file @@ -4122,8 +3790,6 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // fileDeleteByKey deletes a file from Gemini for a single key. func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4154,10 +3820,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key // Handle error response - DELETE returns 200 with empty body on success if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostFileDeleteResponse{ @@ -4165,9 +3828,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4178,14 +3839,12 @@ func (provider *GeminiProvider) FileDelete(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil) } // Try each key until deletion succeeds @@ -4211,14 +3870,11 @@ func (provider *GeminiProvider) FileContent(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - // Gemini doesn't support direct file content download // Files are referenced by their URI in requests return nil, providerUtils.NewBifrostOperationError( "Gemini Files API doesn't support direct content download. Use the file URI in your requests instead.", nil, - providerName, ) } @@ -4249,7 +3905,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiResponsesRequest(request) }, - provider.GetProviderKey(), ) if bifrostErr != nil { return nil, bifrostErr @@ -4261,14 +3916,13 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonData, _ = providerUtils.DeleteJSONField(jsonData, "systemInstruction") } - providerName := provider.GetProviderKey() req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) if strings.TrimSpace(request.Model) == "" { - return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model"), providerName) + return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model")) } // Determine native model name (e.g., parse any provider prefix) @@ -4301,15 +3955,12 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch } if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } responseBody := append([]byte(nil), body...) @@ -4329,9 +3980,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch response := geminiResponse.ToBifrostCountTokensResponse(request.Model) // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.CountTokensRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -4434,7 +4082,7 @@ func (provider *GeminiProvider) Passthrough( headers := providerUtils.ExtractProviderResponseHeaders(resp) body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -4448,9 +4096,6 @@ func (provider *GeminiProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4514,9 +4159,9 @@ func (provider *GeminiProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -4527,7 +4172,6 @@ func (provider *GeminiProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -4539,11 +4183,7 @@ func (provider *GeminiProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4554,9 +4194,9 @@ func (provider *GeminiProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -4605,7 +4245,7 @@ func (provider *GeminiProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/gemini/images.go b/core/providers/gemini/images.go index b390537b3e..881988392a 100644 --- a/core/providers/gemini/images.go +++ b/core/providers/gemini/images.go @@ -408,10 +408,10 @@ func ToGeminiImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" && aspectRatio != "" { geminiReq.GenerationConfig.ImageConfig = &GeminiImageConfig{ - ImageSize: imageSize, + ImageSize: strings.ToLower(imageSize), AspectRatio: aspectRatio, } } @@ -513,9 +513,10 @@ func ToImagenImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" { - req.Parameters.SampleImageSize = &imageSize + imageSizeLower := strings.ToLower(imageSize) + req.Parameters.SampleImageSize = &imageSizeLower } if aspectRatio != "" { req.Parameters.AspectRatio = &aspectRatio @@ -638,55 +639,6 @@ func convertOutputFormatToMimeType(outputFormat string) string { } } -// convertSizeToImagenFormat converts standard size format (e.g., "1024x1024") to Imagen format -// Returns (imageSize, aspectRatio) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9" -func convertSizeToImagenFormat(size string) (string, string) { - // Parse size string (format: "WIDTHxHEIGHT") - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - // Validate width and height are positive integers - if width <= 0 || height <= 0 { - return "", "" - } - - var imageSize string - if width <= 1024 && height <= 1024 { - imageSize = "1k" - } else if width <= 2048 && height <= 2048 { - imageSize = "2k" - } else if width <= 4096 && height <= 4096 { - imageSize = "4k" - } - - // Calculate aspect ratio - var aspectRatio string - ratio := float64(width) / float64(height) - - // Common aspect ratios with tolerance - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return imageSize, aspectRatio -} // ToBifrostImageGenerationResponse converts an Imagen response to Bifrost format func (response *GeminiImagenResponse) ToBifrostImageGenerationResponse() *schemas.BifrostImageGenerationResponse { diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index 4c8f83c364..7b9f6410eb 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -1,9 +1,9 @@ package gemini import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -17,7 +17,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -26,45 +26,47 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - includedModels := make(map[string]bool) - for _, model := range response.Models { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + included := make(map[string]bool) + + for _, model := range response.Models { contextLength := model.InputTokenLimit + model.OutputTokenLimit - // Remove prefix models/ from model.Name + // Gemini returns model names with a "models/" prefix β€” strip it before filtering + // so that allowedModels entries like "gemini-1.5-pro" match correctly. modelName := strings.TrimPrefix(model.Name, "models/") - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelName) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, modelName) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelName, - Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - ContextLength: schemas.Ptr(int(contextLength)), - MaxInputTokens: schemas.Ptr(model.InputTokenLimit), - MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), - SupportedMethods: model.SupportedGenerationMethods, - }) - includedModels[modelName] = true - } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + for _, result := range pipeline.FilterModel(modelName) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + ContextLength: schemas.Ptr(int(contextLength)), + MaxInputTokens: schemas.Ptr(model.InputTokenLimit), + MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), + SupportedMethods: model.SupportedGenerationMethods, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/gemini/videos.go b/core/providers/gemini/videos.go index 62ce110c26..43ece90be4 100644 --- a/core/providers/gemini/videos.go +++ b/core/providers/gemini/videos.go @@ -217,7 +217,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if operation == nil { - return nil, providerUtils.NewBifrostOperationError("operation is nil", nil, schemas.Gemini) + return nil, providerUtils.NewBifrostOperationError("operation is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index 4bbcfd1395..b3c030b386 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -149,9 +149,6 @@ func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/groq/groq_test.go b/core/providers/groq/groq_test.go index 54b7945d76..3bf59588db 100644 --- a/core/providers/groq/groq_test.go +++ b/core/providers/groq/groq_test.go @@ -38,28 +38,28 @@ func TestGroq(t *testing.T) { TranscriptionModel: "whisper-large-v3", SpeechSynthesisModel: "canopylabs/orpheus-v1-english", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, - TextCompletionStream: false, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, + TextCompletionStream: false, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - FileBase64: false, // Not supported - FileURL: false, // Not supported - CompleteEnd2End: true, - Embedding: false, - ListModels: true, - Reasoning: true, - Transcription: true, - SpeechSynthesis: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + FileBase64: false, // Not supported + FileURL: false, // Not supported + CompleteEnd2End: true, + Embedding: false, + ListModels: true, + Reasoning: true, + Transcription: true, + SpeechSynthesis: true, }, } t.Run("GroqTests", func(t *testing.T) { diff --git a/core/providers/huggingface/errors.go b/core/providers/huggingface/errors.go index 49ce427df7..d98357e0a8 100644 --- a/core/providers/huggingface/errors.go +++ b/core/providers/huggingface/errors.go @@ -10,7 +10,7 @@ import ( ) // parseHuggingFaceImageError parses HuggingFace error responses -func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp HuggingFaceResponseError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -53,13 +53,5 @@ func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.Req bifrostErr.Error.Message = errorResp.Error } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index f2fa8a4547..32dadecc61 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -254,12 +254,12 @@ func (provider *HuggingFaceProvider) completeRequest(ctx *schemas.BifrostContext // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp, nil) + return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -325,7 +325,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)} + resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)} return } @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) totalLatency += result.latency @@ -459,10 +459,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionRequest, - }, } } if inferenceProvider != "" { @@ -483,8 +479,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, reqBody.Stream = schemas.Ptr(false) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -518,9 +513,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse.Object = "chat.completion" } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -550,10 +542,6 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionStreamRequest, - }, } } if inferenceProvider != "" { @@ -610,9 +598,6 @@ func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -644,10 +629,6 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.EmbeddingRequest, - }, } } @@ -657,8 +638,7 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceEmbeddingRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -698,13 +678,10 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key // Unmarshal directly to BifrostEmbeddingResponse with custom logic bifrostResponse, convErr := UnmarshalHuggingFaceEmbeddingResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -735,10 +712,6 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.SpeechRequest, - }, } } @@ -747,8 +720,7 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -784,18 +756,15 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch // Download the audio file from the URL audioData, downloadErr := provider.downloadAudioFromURL(ctx, response.Audio.URL) if downloadErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse, convErr := response.ToBifrostSpeechResponse(request.Model, audioData) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.SpeechRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -833,10 +802,6 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.TranscriptionRequest, - }, } } @@ -846,7 +811,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, isHFInferenceAudioRequest := inferenceProvider == hfInference if inferenceProvider == hfInference { if request.Input == nil || len(request.Input.File) == 0 { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests"), provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests")) } jsonData = request.Input.File } else { @@ -856,8 +821,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceTranscriptionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -905,13 +869,10 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, bifrostResponse, convErr := response.ToBifrostTranscriptionResponse(request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TranscriptionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -945,10 +906,6 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationRequest, - }, } } @@ -958,8 +915,7 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageGenerationRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -999,15 +955,12 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext // Unmarshal response using Nebius converter bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1039,10 +992,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationStreamRequest, - }, } } @@ -1050,11 +999,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image generation streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } - providerName := provider.GetProviderKey() // Set headers headers := map[string]string{ @@ -1072,8 +1018,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageStreamRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1105,9 +1050,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1123,9 +1065,9 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -1134,11 +1076,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1161,9 +1099,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1184,6 +1120,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. + startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1202,14 +1140,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1232,11 +1163,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1262,11 +1188,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageGenerationEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1306,11 +1229,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Type: schemas.ImageGenerationEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ @@ -1354,10 +1274,6 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageEditRequest, - }, } } @@ -1372,8 +1288,7 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageEditRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1409,15 +1324,12 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key // Unmarshal response bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1449,10 +1361,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageEditStreamRequest, - }, } } @@ -1460,9 +1368,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image edit streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } var authHeader map[string]string @@ -1488,15 +1394,13 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageEditRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1524,9 +1428,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1542,9 +1443,9 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -1553,11 +1454,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1580,9 +1477,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1603,6 +1498,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. + startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1621,14 +1518,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1651,11 +1541,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1681,11 +1566,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageEditEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1725,11 +1607,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Type: schemas.ImageEditEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index c637c3b6f6..de615ccec2 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -13,7 +14,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,15 +23,20 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } - var blacklisted map[string]struct{} - if !unfiltered && len(blacklistedModels) > 0 { - blacklisted = make(map[string]struct{}, len(blacklistedModels)) - for _, m := range blacklistedModels { - blacklisted[m] = struct{}{} - } + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - includedModels := make(map[string]bool) + included := make(map[string]bool) + for _, model := range response.Models { if model.ModelID == "" { continue @@ -41,39 +47,33 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi continue } - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { - continue - } - if _, ok := blacklisted[model.ModelID]; ok { - continue - } - - newModel := schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), - Name: schemas.Ptr(model.ModelID), - SupportedMethods: supported, - HuggingFaceID: schemas.Ptr(model.ID), - } - - bifrostResponse.Data = append(bifrostResponse.Data, newModel) - includedModels[model.ModelID] = true - } - - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if _, ok := blacklisted[allowedModel]; ok { - continue + // Aliases apply at the model level (model.ModelID), not at the compound + // "{providerKey}/{inferenceProvider}/{modelID}" level. + for _, result := range pipeline.FilterModel(model.ModelID) { + newModel := schemas.Model{ + // inferenceProvider stays in the compound ID; aliases rename only the model segment + ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, result.ResolvedID), + Name: schemas.Ptr(model.ModelID), + SupportedMethods: supported, + HuggingFaceID: schemas.Ptr(model.ID), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + newModel.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, newModel) + included[strings.ToLower(result.ResolvedID)] = true } } + // Backfill: use standard pipeline. Note that backfilled HF entries use a simplified + // compound ID since we don't know which inferenceProvider to assign them to. + for _, m := range pipeline.BackfillModels(included) { + // Re-wrap the backfill ID to include the inferenceProvider segment + rawID := strings.TrimPrefix(m.ID, string(providerKey)+"/") + m.ID = fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, rawID) + bifrostResponse.Data = append(bifrostResponse.Data, m) + } + return bifrostResponse } diff --git a/core/providers/huggingface/responses.go b/core/providers/huggingface/responses.go index fd68aa76a8..35ad2c336d 100644 --- a/core/providers/huggingface/responses.go +++ b/core/providers/huggingface/responses.go @@ -43,9 +43,6 @@ func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse responsesResp := resp.ToBifrostResponsesResponse() if responsesResp != nil { - responsesResp.ExtraFields.Provider = schemas.HuggingFace - responsesResp.ExtraFields.ModelRequested = requestedModel - responsesResp.ExtraFields.RequestType = schemas.ResponsesRequest } return responsesResp, nil diff --git a/core/providers/huggingface/speech.go b/core/providers/huggingface/speech.go index 65c0ba6e12..f702d1f39f 100644 --- a/core/providers/huggingface/speech.go +++ b/core/providers/huggingface/speech.go @@ -125,10 +125,6 @@ func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedMode // Create the base Bifrost response with the downloaded audio data bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: audioData, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Note: HuggingFace TTS API typically doesn't return usage information diff --git a/core/providers/huggingface/transcription.go b/core/providers/huggingface/transcription.go index 0d892cb07c..f3ff5c293a 100644 --- a/core/providers/huggingface/transcription.go +++ b/core/providers/huggingface/transcription.go @@ -144,10 +144,6 @@ func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse // Create the base Bifrost response bifrostResponse := &schemas.BifrostTranscriptionResponse{ Text: response.Text, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Map chunks to segments if available diff --git a/core/providers/huggingface/utils.go b/core/providers/huggingface/utils.go index ad68e4c6ac..b96210c832 100644 --- a/core/providers/huggingface/utils.go +++ b/core/providers/huggingface/utils.go @@ -221,8 +221,6 @@ func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappin } func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Check cache first if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok { if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok { @@ -259,12 +257,12 @@ func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx contex body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var mappingResp HuggingFaceInferenceProviderMappingResponse if err := sonic.Unmarshal(body, &mappingResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } mappings := convertToInferenceProviderMappings(&mappingResp) diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 4469c4abc3..fb11cfbc88 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -74,8 +74,6 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -101,7 +99,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } @@ -116,7 +114,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -212,9 +210,6 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -273,7 +268,7 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form body @@ -310,12 +305,12 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -343,20 +338,17 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostTranscriptionResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.TranscriptionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -378,7 +370,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } mistralReq.Stream = schemas.Ptr(true) @@ -433,9 +425,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -445,7 +437,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -464,9 +456,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -505,7 +497,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) } break } @@ -553,11 +545,6 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -586,11 +573,8 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( // Set extra fields response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(*lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(*lastChunkTime).Milliseconds(), } *lastChunkTime = time.Now() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index ef3e5934c1..8d5fd7f3d6 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go @@ -1,12 +1,13 @@ package mistral import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,40 +16,40 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { - continue - } - if slices.Contains(blacklistedModels, model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + model.ID, - Name: schemas.Ptr(model.Name), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.Created), - ContextLength: schemas.Ptr(int(model.MaxContextLength)), - OwnedBy: schemas.Ptr(model.OwnedBy), - }) - includedModels[model.ID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Mistral, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Mistral) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.Created), + ContextLength: schemas.Ptr(int(model.MaxContextLength)), + OwnedBy: schemas.Ptr(model.OwnedBy), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - includedModels[allowedModel] = true + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/mistral/transcription.go b/core/providers/mistral/transcription.go index a4a018e5c6..fe9b262126 100644 --- a/core/providers/mistral/transcription.go +++ b/core/providers/mistral/transcription.go @@ -109,58 +109,58 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *Mi } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(req.File); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file data", err) } // Add model field (required) if err := writer.WriteField("model", req.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add stream field if streaming if req.Stream != nil && *req.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } // Add optional fields if req.Language != nil { if err := writer.WriteField("language", *req.Language); err != nil { - return providerUtils.NewBifrostOperationError("failed to write language field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write language field", err) } } if req.Prompt != nil { if err := writer.WriteField("prompt", *req.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } } if req.ResponseFormat != nil { if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if req.Temperature != nil { if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range req.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } // Close the multipart writer to finalize the form if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/nebius/errors.go b/core/providers/nebius/errors.go index 98d0fb78d8..de8bcf0d84 100644 --- a/core/providers/nebius/errors.go +++ b/core/providers/nebius/errors.go @@ -9,7 +9,7 @@ import ( ) // parseNebiusImageError parses Nebius error responses -func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseNebiusImageError(resp *fasthttp.Response) *schemas.BifrostError { var nebiusErr NebiusError bifrostErr := providerUtils.HandleProviderAPIError(resp, &nebiusErr) @@ -60,13 +60,5 @@ func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestM bifrostErr.Error.Message = message } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index d8bd8a2256..42429b87ab 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -193,9 +193,6 @@ func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -260,16 +257,15 @@ func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Validate request is not nil if request == nil { - return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil) } // Validate input and prompt are not nil/empty if request.Input == nil || strings.TrimSpace(request.Input.Prompt) == "" { - return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil) } path := providerUtils.GetPathFromContext(ctx, "/v1/images/generations") - providerName := schemas.Nebius // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -304,8 +300,7 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return provider.ToNebiusImageGenerationRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -323,16 +318,12 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp), jsonData, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } response := &schemas.BifrostImageGenerationResponse{} @@ -352,9 +343,6 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled diff --git a/core/providers/nebius/nebius_test.go b/core/providers/nebius/nebius_test.go index 898617cb1e..da6c4065e8 100644 --- a/core/providers/nebius/nebius_test.go +++ b/core/providers/nebius/nebius_test.go @@ -32,25 +32,25 @@ func TestNebius(t *testing.T) { EmbeddingModel: "BAAI/bge-en-icl", ImageGenerationModel: "black-forest-labs/flux-schnell", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - ImageGeneration: true, - CompleteEnd2End: true, - ImageGenerationStream: false, - Embedding: true, // Nebius supports embeddings - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + ImageGeneration: true, + CompleteEnd2End: true, + ImageGenerationStream: false, + Embedding: true, // Nebius supports embeddings + ListModels: true, }, } diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index a63b05d0c1..82b9471a41 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -3,7 +3,6 @@ package ollama import ( - "fmt" "strings" "time" @@ -50,11 +49,7 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for Ollama - if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for ollama provider") - } - + // BaseURL is optional when keys have ollama_key_config with per-key URLs return &OllamaProvider{ logger: logger, client: client, @@ -69,17 +64,14 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } -// ListModels performs a list models request to Ollama's API. -func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) - } - return openai.HandleOpenAIListModelsRequest( +// listModelsByKey performs a list models request for a single Ollama key. +func (provider *OllamaProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return openai.ListModelsByKey( ctx, provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), + key, + request.Unfiltered, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -87,12 +79,24 @@ func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []s ) } +// ListModels performs a list models request to Ollama's API. +// Requests are made concurrently per key so that each backend is queried +// with its own URL (from ollama_key_config). +func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return providerUtils.HandleMultipleListModelsRequests( + ctx, + keys, + request, + provider.listModelsByKey, + ) +} + // TextCompletion performs a text completion request to the Ollama API. func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -112,7 +116,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/completions", + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -132,7 +136,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -154,7 +158,7 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -179,9 +183,6 @@ func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -202,7 +203,7 @@ func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schem return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/ollama/ollama_test.go b/core/providers/ollama/ollama_test.go index ad31297046..a9005c1220 100644 --- a/core/providers/ollama/ollama_test.go +++ b/core/providers/ollama/ollama_test.go @@ -29,24 +29,24 @@ func TestOllama(t *testing.T) { TextModel: "", // Ollama doesn't support text completion in newer models EmbeddingModel: "", // Ollama doesn't support embedding Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - FileBase64: false, - FileURL: false, - CompleteEnd2End: true, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + FileBase64: false, + FileURL: false, + CompleteEnd2End: true, + Embedding: false, + ListModels: true, }, } diff --git a/core/providers/openai/batch.go b/core/providers/openai/batch.go index ae095e5c77..ec8ce468bb 100644 --- a/core/providers/openai/batch.go +++ b/core/providers/openai/batch.go @@ -10,10 +10,10 @@ import ( // OpenAIBatchRequest represents the request body for creating a batch. type OpenAIBatchRequest struct { - InputFileID string `json:"input_file_id"` - Endpoint string `json:"endpoint"` - CompletionWindow string `json:"completion_window"` - Metadata map[string]string `json:"metadata,omitempty"` + InputFileID string `json:"input_file_id"` + Endpoint string `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]string `json:"metadata,omitempty"` OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"` } @@ -82,7 +82,7 @@ func ToBifrostBatchStatus(status string) schemas.BatchStatus { } // ToBifrostBatchCreateResponse converts OpenAI batch response to Bifrost batch response. -func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, Object: r.Object, @@ -95,9 +95,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. OutputFileID: r.OutputFileID, ErrorFileID: r.ErrorFileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -125,7 +123,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. } // ToBifrostBatchRetrieveResponse converts OpenAI batch response to Bifrost batch retrieve response. -func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: r.Object, @@ -146,9 +144,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema ErrorFileID: r.ErrorFileID, Errors: r.Errors, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -174,35 +170,3 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema return resp } - -// splitJSONL splits JSONL content into individual lines. -func splitJSONL(data []byte) [][]byte { - var lines [][]byte - start := 0 - for i, b := range data { - if b == '\n' { - if i > start { - end := i - // Strip trailing \r if present (handle CRLF) - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - start = i + 1 - } - } - if start < len(data) { - end := len(data) - // Strip trailing \r if present - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - return lines -} diff --git a/core/providers/openai/chat_test.go b/core/providers/openai/chat_test.go index f391f821cb..f5e08c7f8e 100644 --- a/core/providers/openai/chat_test.go +++ b/core/providers/openai/chat_test.go @@ -277,7 +277,6 @@ func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/openai/errors.go b/core/providers/openai/errors.go index 6a5bc1ce08..69d0aff407 100644 --- a/core/providers/openai/errors.go +++ b/core/providers/openai/errors.go @@ -10,10 +10,10 @@ import ( ) // ErrorConverter is a function that converts provider-specific error responses to BifrostError. -type ErrorConverter func(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError +type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError // ParseOpenAIError parses OpenAI error responses. -func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp schemas.BifrostError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -49,11 +49,6 @@ func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, } // Set ExtraFields unconditionally so provider/model/request metadata is always attached - if bifrostErr != nil { - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - } return bifrostErr } diff --git a/core/providers/openai/errors_test.go b/core/providers/openai/errors_test.go index f33008600b..1132a92723 100644 --- a/core/providers/openai/errors_test.go +++ b/core/providers/openai/errors_test.go @@ -12,7 +12,7 @@ func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *tes resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`) - errResp := ParseOpenAIError(&resp, schemas.ResponsesStreamRequest, schemas.Cerebras, "llama3.1-8b") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -29,7 +29,7 @@ func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) { resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"error":{"message":"unsupported role: developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -43,7 +43,7 @@ func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBody(nil) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -59,7 +59,7 @@ func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -73,7 +73,7 @@ func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing. // fasthttp defaults zero-value response status code to 200. resp.SetBodyString(`{"error":{"message":""}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } diff --git a/core/providers/openai/files.go b/core/providers/openai/files.go index bbaf2b2f70..133250cac7 100644 --- a/core/providers/openai/files.go +++ b/core/providers/openai/files.go @@ -55,7 +55,7 @@ func ToBifrostFileStatus(status string) schemas.FileStatus { } // ToBifrostFileUploadResponse converts OpenAI file response to Bifrost file upload response. -func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Object, @@ -67,9 +67,7 @@ func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mo StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -97,9 +95,7 @@ func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas. StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/openai/images.go b/core/providers/openai/images.go index f183e17ec5..9176f1e1e7 100644 --- a/core/providers/openai/images.go +++ b/core/providers/openai/images.go @@ -125,18 +125,18 @@ func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *Open func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add prompt field (required) if err := writer.WriteField("prompt", openaiReq.Input.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add stream field when requesting streaming if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } @@ -168,71 +168,71 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err) } if _, err := part.Write(imageInput.Image); err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err) } } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Quality != nil { if err := writer.WriteField("quality", *openaiReq.Quality); err != nil { - return providerUtils.NewBifrostOperationError("failed to write quality field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write quality field", err) } } if openaiReq.Background != nil { if err := writer.WriteField("background", *openaiReq.Background); err != nil { - return providerUtils.NewBifrostOperationError("failed to write background field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write background field", err) } } if openaiReq.InputFidelity != nil { if err := writer.WriteField("input_fidelity", *openaiReq.InputFidelity); err != nil { - return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err) } } if openaiReq.PartialImages != nil { if err := writer.WriteField("partial_images", strconv.Itoa(*openaiReq.PartialImages)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write partial_images field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write partial_images field", err) } } if openaiReq.OutputFormat != nil { if err := writer.WriteField("output_format", *openaiReq.OutputFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_format field", err) } } if openaiReq.OutputCompression != nil { if err := writer.WriteField("output_compression", strconv.Itoa(*openaiReq.OutputCompression)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_compression field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_compression field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } @@ -260,16 +260,16 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {maskMimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create mask form part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create mask form part", err) } if _, err := maskPart.Write(openaiReq.Mask); err != nil { - return providerUtils.NewBifrostOperationError("failed to write mask data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write mask data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil @@ -299,12 +299,12 @@ func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add image file (required) if openaiReq.Input == nil || openaiReq.Input.Image.Image == nil || len(openaiReq.Input.Image.Image) == 0 { - return providerUtils.NewBifrostOperationError("image is required", nil, providerName) + return providerUtils.NewBifrostOperationError("image is required", nil) } // Detect MIME type @@ -320,41 +320,41 @@ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openai "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create image part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create image part", err) } if _, err := part.Write(openaiReq.Input.Image.Image); err != nil { - return providerUtils.NewBifrostOperationError("failed to write image data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write image data", err) } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/large_payload.go b/core/providers/openai/large_payload.go index 461f3417de..fe3aaf1812 100644 --- a/core/providers/openai/large_payload.go +++ b/core/providers/openai/large_payload.go @@ -42,8 +42,6 @@ func handleOpenAILargePayloadPassthrough( key schemas.Key, extraHeaders map[string]string, providerName schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) (*largePayloadResult, *schemas.BifrostError, bool) { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) @@ -91,7 +89,7 @@ func handleOpenAILargePayloadPassthrough( // Error responses are always small β€” materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - parsedErr := ParseOpenAIError(resp, requestType, providerName, model) + parsedErr := ParseOpenAIError(resp) fasthttp.ReleaseResponse(resp) return nil, parsedErr, true } @@ -126,7 +124,7 @@ func finalizeOpenAIResponse( providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, *largePayloadResult, *schemas.BifrostError) { - body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, logger) + body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger) if bifrostErr != nil { fasthttp.ReleaseResponse(resp) return nil, nil, bifrostErr diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index d00a8af112..a76d350d28 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -1,13 +1,14 @@ package openai import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -16,38 +17,39 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.ID, - Created: model.Created, - OwnedBy: schemas.Ptr(model.OwnedBy), - ContextLength: model.ContextWindow, - }) - includedModels[model.ID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Created: model.Created, + OwnedBy: schemas.Ptr(model.OwnedBy), + ContextLength: model.ContextWindow, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 387da52431..197334a22c 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -166,7 +166,7 @@ func ListModelsByKey( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseOpenAIError(resp) return nil, bifrostErr } @@ -181,10 +181,8 @@ func ListModelsByKey( return nil, bifrostErr } - response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered) + response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, unfiltered) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -289,22 +287,22 @@ func HandleOpenAITextCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TextCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTextCompletionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -313,8 +311,7 @@ func HandleOpenAITextCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -335,9 +332,9 @@ func HandleOpenAITextCompletionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -349,7 +346,7 @@ func HandleOpenAITextCompletionRequest( return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -367,9 +364,6 @@ func HandleOpenAITextCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -455,8 +449,7 @@ func HandleOpenAITextCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -501,9 +494,9 @@ func HandleOpenAITextCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -514,9 +507,9 @@ func HandleOpenAITextCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -533,9 +526,9 @@ func HandleOpenAITextCompletionStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -557,7 +550,7 @@ func HandleOpenAITextCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -584,7 +577,7 @@ func HandleOpenAITextCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -595,11 +588,6 @@ func HandleOpenAITextCompletionStreaming( rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { // TODO fix this - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -618,11 +606,6 @@ func HandleOpenAITextCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -699,9 +682,6 @@ func HandleOpenAITextCompletionStreaming( if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { chunkIndex++ - response.ExtraFields.RequestType = schemas.TextCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -719,7 +699,7 @@ func HandleOpenAITextCompletionStreaming( } } - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -811,22 +791,22 @@ func HandleOpenAIChatCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ChatCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostChatResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -835,8 +815,7 @@ func HandleOpenAIChatCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -858,9 +837,9 @@ func HandleOpenAIChatCompletionRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -872,7 +851,7 @@ func HandleOpenAIChatCompletionRequest( return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } response := &schemas.BifrostChatResponse{} @@ -890,9 +869,6 @@ func HandleOpenAIChatCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ChatCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1008,8 +984,7 @@ func HandleOpenAIChatCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1054,9 +1029,9 @@ func HandleOpenAIChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1067,9 +1042,9 @@ func HandleOpenAIChatCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1082,19 +1057,13 @@ func HandleOpenAIChatCompletionStreaming( // Create response channel responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) - // Determine request type for cleanup - streamRequestType := schemas.ChatCompletionStreamRequest - if isResponsesToChatCompletionsFallback { - streamRequestType = schemas.ResponsesStreamRequest - } - // Start streaming in a goroutine go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } // Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback) schemas.ReleaseChatToResponsesStreamState(responsesStreamState) @@ -1118,7 +1087,7 @@ func HandleOpenAIChatCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1132,6 +1101,8 @@ func HandleOpenAIChatCompletionStreaming( var finishReason *string var messageID string + var modelName string + var created int forwardedTerminalFinishReason := false for { @@ -1147,7 +1118,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -1160,11 +1131,6 @@ func HandleOpenAIChatCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1178,11 +1144,6 @@ func HandleOpenAIChatCompletionStreaming( if customResponseHandler != nil { rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -1213,11 +1174,6 @@ func HandleOpenAIChatCompletionStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: streamRequestType, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1235,9 +1191,6 @@ func HandleOpenAIChatCompletionStreaming( return } - response.ExtraFields.RequestType = streamRequestType - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber if sendBackRawResponse { @@ -1300,6 +1253,10 @@ func HandleOpenAIChatCompletionStreaming( response.Usage = nil } + if response.Model != "" { + modelName = response.Model + } + // Skip empty responses or responses without choices if len(response.Choices) == 0 { continue @@ -1315,6 +1272,9 @@ func HandleOpenAIChatCompletionStreaming( if response.ID != "" && messageID == "" { messageID = response.ID } + if response.Created != 0 && created == 0 { + created = response.Created + } // Handle regular content chunks, including reasoning if choice.ChatStreamResponseChoice != nil && @@ -1329,9 +1289,6 @@ func HandleOpenAIChatCompletionStreaming( } chunkIndex++ - response.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -1355,7 +1312,7 @@ func HandleOpenAIChatCompletionStreaming( if forwardedTerminalFinishReason { finalFinishReason = nil } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, streamRequestType, providerName, request.Model) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, modelName, created) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1442,21 +1399,21 @@ func HandleOpenAIResponsesRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ResponsesRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostResponsesResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1466,8 +1423,7 @@ func HandleOpenAIResponsesRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1489,9 +1445,9 @@ func HandleOpenAIResponsesRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1502,7 +1458,7 @@ func HandleOpenAIResponsesRequest( if lpResult != nil { return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1520,9 +1476,6 @@ func HandleOpenAIResponsesRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ResponsesRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1619,8 +1572,7 @@ func HandleOpenAIResponsesStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1664,9 +1616,9 @@ func HandleOpenAIResponsesStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1677,9 +1629,9 @@ func HandleOpenAIResponsesStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -1696,9 +1648,9 @@ func HandleOpenAIResponsesStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1720,7 +1672,7 @@ func HandleOpenAIResponsesStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1742,7 +1694,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1754,11 +1706,6 @@ func HandleOpenAIResponsesStreaming( if customResponseHandler != nil { rawRequest, rawResponse, bifrostErr := customResponseHandler([]byte(jsonData), &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } if sendBackRawRequest { bifrostErr.ExtraFields.RawRequest = rawRequest } @@ -1792,11 +1739,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1821,11 +1763,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Response != nil && response.Response.Error != nil { bifrostErr.Error.Message = response.Response.Error.Message @@ -1836,11 +1773,7 @@ func HandleOpenAIResponsesStreaming( return } - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber - if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete { // Set raw request if enabled if sendBackRawRequest { @@ -1858,7 +1791,6 @@ func HandleOpenAIResponsesStreaming( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan) } } - }() return responseChan, nil @@ -1929,22 +1861,22 @@ func HandleOpenAIEmbeddingRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.EmbeddingRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostEmbeddingResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1954,8 +1886,7 @@ func HandleOpenAIEmbeddingRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIEmbeddingRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1976,7 +1907,7 @@ func HandleOpenAIEmbeddingRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.EmbeddingRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1988,7 +1919,7 @@ func HandleOpenAIEmbeddingRequest( return &schemas.BifrostEmbeddingResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2006,9 +1937,6 @@ func HandleOpenAIEmbeddingRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.EmbeddingRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2087,22 +2015,21 @@ func HandleOpenAISpeechRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.SpeechRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } // Speech response is raw audio bytes (MP3/WAV), not JSON return &schemas.BifrostSpeechResponse{ Audio: lpResult.ResponseBody, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil }, - providerName) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil }) if bifrostErr != nil { return nil, bifrostErr } @@ -2123,7 +2050,7 @@ func HandleOpenAISpeechRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Get the binary audio data from the response body @@ -2134,7 +2061,7 @@ func HandleOpenAISpeechRequest( } if lpResult != nil { return &schemas.BifrostSpeechResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2144,9 +2071,6 @@ func HandleOpenAISpeechRequest( bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: body, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -2169,7 +2093,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo for _, model := range providerUtils.UnsupportedSpeechStreamModels { if model == request.Model { - return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil) } } @@ -2254,8 +2178,7 @@ func HandleOpenAISpeechStreamRequest( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2281,9 +2204,9 @@ func HandleOpenAISpeechStreamRequest( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -2293,7 +2216,7 @@ func HandleOpenAISpeechStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -2310,9 +2233,9 @@ func HandleOpenAISpeechStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2334,7 +2257,7 @@ func HandleOpenAISpeechStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2358,7 +2281,7 @@ func HandleOpenAISpeechStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2370,11 +2293,6 @@ func HandleOpenAISpeechStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.SpeechStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -2400,11 +2318,8 @@ func HandleOpenAISpeechStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2425,7 +2340,6 @@ func HandleOpenAISpeechStreamRequest( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan) } - }() return responseChan, nil @@ -2466,7 +2380,7 @@ func HandleOpenAITranscriptionRequest( logger schemas.Logger, ) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TranscriptionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } @@ -2474,13 +2388,13 @@ func HandleOpenAITranscriptionRequest( if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTranscriptionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2509,7 +2423,7 @@ func HandleOpenAITranscriptionRequest( // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form @@ -2536,7 +2450,7 @@ func HandleOpenAITranscriptionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } responseBody, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -2546,7 +2460,7 @@ func HandleOpenAITranscriptionRequest( } if lpResult != nil { return &schemas.BifrostTranscriptionResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2566,7 +2480,12 @@ func HandleOpenAITranscriptionRequest( // Parse OpenAI's transcription response directly into BifrostTranscribe response := &schemas.BifrostTranscriptionResponse{} var rawResponse interface{} - if customResponseHandler != nil { + if request.Params != nil && schemas.IsPlainTextTranscriptionFormat(request.Params.ResponseFormat) { + response.Text = string(copiedResponseBody) + if sendBackRawResponse { + rawResponse = string(copiedResponseBody) + } + } else if customResponseHandler != nil { _, rawResponse, bifrostErr = customResponseHandler(copiedResponseBody, response, nil, false, sendBackRawResponse) } else { if err := sonic.Unmarshal(copiedResponseBody, response); err != nil { @@ -2580,15 +2499,15 @@ func HandleOpenAITranscriptionRequest( }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - //TODO: add HandleProviderResponse here + // TODO: add HandleProviderResponse here // Parse raw response for RawResponse field if sendBackRawResponse { if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err) } } } @@ -2598,9 +2517,6 @@ func HandleOpenAITranscriptionRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -2662,7 +2578,7 @@ func HandleOpenAITranscriptionStreamRequest( // Use centralized converter reqBody := ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) if postRequestConverter != nil { @@ -2723,9 +2639,9 @@ func HandleOpenAITranscriptionStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -2735,7 +2651,7 @@ func HandleOpenAITranscriptionStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -2752,9 +2668,9 @@ func HandleOpenAITranscriptionStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -2776,7 +2692,7 @@ func HandleOpenAITranscriptionStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -2801,7 +2717,7 @@ func HandleOpenAITranscriptionStreamRequest( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -2812,11 +2728,6 @@ func HandleOpenAITranscriptionStreamRequest( if customResponseHandler != nil { _, _, bifrostErr = customResponseHandler([]byte(jsonData), response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } if sendBackRawResponse { bifrostErr.ExtraFields.RawResponse = jsonData } @@ -2831,13 +2742,9 @@ func HandleOpenAITranscriptionStreamRequest( var bifrostErrVal schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErrVal); err == nil { if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" { - bifrostErrVal.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, nil, nil, false, sendBackRawResponse), responseChan, logger) + respBody := append([]byte(nil), resp.Body()...) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger) return } } @@ -2861,11 +2768,8 @@ func HandleOpenAITranscriptionStreamRequest( chunkIndex++ response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -2887,7 +2791,6 @@ func HandleOpenAITranscriptionStreamRequest( providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan) } - }() return responseChan, nil @@ -2897,8 +2800,8 @@ func HandleOpenAITranscriptionStreamRequest( // It formats the request, sends it to OpenAI, and processes the response. // Returns a BifrostResponse containing the bifrost response or an error if the request fails. func (provider *OpenAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, - req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - + req *schemas.BifrostImageGenerationRequest, +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil { return nil, err } @@ -2931,7 +2834,6 @@ func HandleOpenAIImageGenerationRequest( sendBackRawResponse bool, logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -2957,20 +2859,20 @@ func HandleOpenAIImageGenerationRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageGenerationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -2980,8 +2882,7 @@ func HandleOpenAIImageGenerationRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIImageGenerationRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -3002,7 +2903,7 @@ func HandleOpenAIImageGenerationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -3012,7 +2913,7 @@ func HandleOpenAIImageGenerationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -3024,9 +2925,6 @@ func HandleOpenAIImageGenerationRequest( return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -3052,9 +2950,8 @@ func (provider *OpenAIProvider) ImageGenerationStream( key schemas.Key, request *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } // Check if image generation stream is allowed for this provider @@ -3101,7 +2998,6 @@ func HandleOpenAIImageGenerationStreaming( postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse, logger schemas.Logger, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - // Set headers headers := map[string]string{ "Content-Type": "application/json", @@ -3129,8 +3025,7 @@ func HandleOpenAIImageGenerationStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -3175,9 +3070,9 @@ func HandleOpenAIImageGenerationStreaming( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -3187,7 +3082,7 @@ func HandleOpenAIImageGenerationStreaming( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -3204,9 +3099,9 @@ func HandleOpenAIImageGenerationStreaming( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -3228,7 +3123,7 @@ func HandleOpenAIImageGenerationStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -3255,7 +3150,7 @@ func HandleOpenAIImageGenerationStreaming( if readErr != nil { if readErr != io.EOF { logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -3267,11 +3162,6 @@ func HandleOpenAIImageGenerationStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -3291,11 +3181,6 @@ func HandleOpenAIImageGenerationStreaming( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -3396,11 +3281,8 @@ func HandleOpenAIImageGenerationStreaming( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -3463,7 +3345,6 @@ func HandleOpenAIImageGenerationStreaming( return } } - }() return responseChan, nil @@ -3502,7 +3383,7 @@ func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3531,7 +3412,7 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3572,12 +3453,12 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Get content type from response @@ -3595,8 +3476,6 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDownloadRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3612,7 +3491,7 @@ func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -3682,10 +3561,10 @@ func HandleOpenAIVideoGenerationRequest( // Use centralized converter reqBody, err := ToOpenAIVideoGenerationRequest(request) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil) } // Create multipart form @@ -3711,12 +3590,12 @@ func HandleOpenAIVideoGenerationRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoGenerationRequest, providerName, request.Model) + return nil, ParseOpenAIError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -3742,9 +3621,6 @@ func HandleOpenAIVideoGenerationRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: providerName, - ModelRequested: request.Model, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -3809,12 +3685,12 @@ func HandleOpenAIVideoRetrieveRequest( if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoRetrieveRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } response := &schemas.BifrostVideoGenerationResponse{} @@ -3856,8 +3732,6 @@ func HandleOpenAIVideoRetrieveRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -3909,12 +3783,12 @@ func HandleOpenAIVideoDeleteRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoDeleteRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse OpenAI's video response @@ -3928,8 +3802,6 @@ func HandleOpenAIVideoDeleteRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -4002,12 +3874,12 @@ func HandleOpenAIVideoListRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } response := &schemas.BifrostVideoListResponse{} @@ -4034,8 +3906,6 @@ func HandleOpenAIVideoListRequest( } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, } @@ -4105,20 +3975,20 @@ func HandleOpenAICountTokensRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.CountTokensRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostCountTokensResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4127,9 +3997,7 @@ func HandleOpenAICountTokensRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -4150,7 +4018,7 @@ func HandleOpenAICountTokensRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.CountTokensRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4160,7 +4028,7 @@ func HandleOpenAICountTokensRequest( } if lpResult != nil { return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4173,9 +4041,6 @@ func HandleOpenAICountTokensRequest( } response.Model = request.Model - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4223,26 +4088,26 @@ func HandleOpenAIImageEditRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageEditRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageEditRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4289,7 +4154,7 @@ func HandleOpenAIImageEditRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4299,7 +4164,7 @@ func HandleOpenAIImageEditRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4308,9 +4173,6 @@ func HandleOpenAIImageEditRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4372,10 +4234,9 @@ func HandleOpenAIImageEditStreamRequest( postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse, logger schemas.Logger, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - reqBody := ToOpenAIImageEditRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -4435,9 +4296,9 @@ func HandleOpenAIImageEditStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -4446,7 +4307,7 @@ func HandleOpenAIImageEditStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditStreamRequest, providerName, request.Model), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -4463,9 +4324,9 @@ func HandleOpenAIImageEditStreamRequest( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -4487,7 +4348,7 @@ func HandleOpenAIImageEditStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -4514,7 +4375,7 @@ func HandleOpenAIImageEditStreamRequest( if readErr != nil { if readErr != io.EOF { logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -4526,11 +4387,6 @@ func HandleOpenAIImageEditStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -4550,11 +4406,6 @@ func HandleOpenAIImageEditStreamRequest( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -4655,11 +4506,8 @@ func HandleOpenAIImageEditStreamRequest( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -4718,7 +4566,6 @@ func HandleOpenAIImageEditStreamRequest( return } } - }() return responseChan, nil @@ -4760,26 +4607,26 @@ func HandleOpenAIImageVariationRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageVariationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageVariationRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4825,7 +4672,7 @@ func HandleOpenAIImageVariationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageVariationRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4835,7 +4682,7 @@ func HandleOpenAIImageVariationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4844,9 +4691,6 @@ func HandleOpenAIImageVariationRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageVariationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4863,14 +4707,12 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Create multipart form data @@ -4879,16 +4721,16 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add expires_after fields if provided if request.ExpiresAfter != nil { if err := writer.WriteField("expires_after[anchor]", request.ExpiresAfter.Anchor); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err) } if err := writer.WriteField("expires_after[seconds]", fmt.Sprintf("%d", request.ExpiresAfter.Seconds)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err) } } @@ -4899,14 +4741,14 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -4936,13 +4778,13 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp OpenAIFileResponse @@ -4953,7 +4795,7 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, bifrostErr } - fileResponse := openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) + fileResponse := openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) fileResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) return fileResponse, nil } @@ -4972,7 +4814,7 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -4983,10 +4825,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -5036,12 +4874,12 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp OpenAIFileListResponse @@ -5077,8 +4915,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -5099,7 +4935,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5134,7 +4970,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5144,7 +4980,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5175,7 +5011,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5210,7 +5046,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5220,7 +5056,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5241,9 +5077,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5270,7 +5104,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -5301,7 +5135,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5311,7 +5145,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5330,9 +5164,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -5349,10 +5181,10 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } if request.Input == nil || request.Input.Prompt == "" { - return nil, providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("prompt is required", nil) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -5360,8 +5192,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIVideoRemixRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -5399,12 +5230,12 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoRemixRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Parse OpenAI's video response @@ -5422,9 +5253,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.VideoRemixRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } if sendBackRawResponse { @@ -5443,8 +5272,6 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -5452,7 +5279,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Convert inline requests to JSONL format jsonlData, err := ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" @@ -5471,12 +5298,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil) } // Validate that we have an endpoint if request.Endpoint == "" { - return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil) } // Create request @@ -5511,7 +5338,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -5527,12 +5354,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, ""), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } var openAIResp OpenAIBatchResponse @@ -5541,7 +5368,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs using serial pagination across keys. @@ -5551,14 +5378,13 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -5569,10 +5395,6 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } @@ -5616,12 +5438,12 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, ParseOpenAIError(resp, schemas.BatchListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp OpenAIBatchListResponse @@ -5634,7 +5456,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -5648,9 +5470,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -5667,10 +5487,9 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, request.Provider) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -5702,7 +5521,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5712,7 +5531,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5728,8 +5547,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -5743,10 +5561,9 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -5778,7 +5595,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5788,7 +5605,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5811,9 +5628,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys [] CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5853,11 +5668,9 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (this already iterates over keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -5868,7 +5681,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file - try each key @@ -5898,7 +5711,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5908,7 +5721,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5932,9 +5745,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [ BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5954,14 +5765,12 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.Name == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil) } // Build request body @@ -5997,7 +5806,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key jsonBody, err := providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Create request @@ -6026,7 +5835,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, ParseOpenAIError(resp, schemas.ContainerCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6060,9 +5869,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerCreateRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6079,16 +5886,14 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key // ContainerList lists containers via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } @@ -6102,7 +5907,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -6113,10 +5918,6 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys Object: "list", Data: []schemas.ContainerObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - }, }, nil } @@ -6163,7 +5964,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, ParseOpenAIError(resp, schemas.ContainerListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Parse response @@ -6198,9 +5999,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6221,20 +6020,18 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys // ContainerRetrieve retrieves a specific container via OpenAI's API. func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerRetrieveRequest); err != nil { @@ -6269,7 +6066,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6309,9 +6106,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k MemoryLimit: containerResp.MemoryLimit, Metadata: containerResp.Metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6332,20 +6127,18 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k // ContainerDelete deletes a container via OpenAI's API. func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("provider config not found", nil) } } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("container_id is required", nil) } if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerDeleteRequest); err != nil { @@ -6380,7 +6173,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - lastErr = ParseOpenAIError(resp, schemas.ContainerDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6408,9 +6201,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6439,14 +6230,12 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } // Create request @@ -6463,7 +6252,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle file upload (multipart only) if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil) } // Multipart file upload @@ -6473,13 +6262,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Add file part, err := writer.CreateFormFile("file", "file") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err) } if _, err = part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err) } req.Header.Set("Content-Type", writer.FormDataContentType()) req.SetBody(body.Bytes()) @@ -6497,13 +6286,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() >= 400 { - return nil, ParseOpenAIError(resp, schemas.ContainerFileCreateRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -6532,9 +6321,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileCreateRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6552,21 +6339,19 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, // ContainerFileList lists files in a container via OpenAI's API. // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next. func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6580,7 +6365,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // Initialize serial pagination helper for multi-key support helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -6591,10 +6376,6 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k Object: "list", Data: []schemas.ContainerFileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - }, }, nil } @@ -6640,13 +6421,13 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k } if resp.StatusCode() >= 400 { - return nil, ParseOpenAIError(resp, schemas.ContainerFileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var listResp struct { @@ -6678,9 +6459,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k LastID: listResp.LastID, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileListRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6701,13 +6480,11 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k // ContainerFileRetrieve retrieves a file from a container via OpenAI's API. func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6716,15 +6493,15 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6752,7 +6529,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6761,7 +6538,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6796,9 +6573,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex Path: fileResp.Path, Source: fileResp.Source, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileRetrieveRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6819,13 +6594,11 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex // ContainerFileContent retrieves the content of a file from a container via OpenAI's API. func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6834,15 +6607,15 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6870,7 +6643,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6887,7 +6660,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } content := append([]byte(nil), body...) @@ -6896,9 +6669,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileContentRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -6922,13 +6693,11 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext // ContainerFileDelete deletes a file from a container via OpenAI's API. func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if len(keys) == 0 { if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess { keys = []schemas.Key{{}} } else { - return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided", nil) } } @@ -6937,15 +6706,15 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if request == nil { - return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil) } if request.ContainerID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil) } if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil) } var lastErr *schemas.BifrostError @@ -6973,7 +6742,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, } if resp.StatusCode() >= 400 { - lastErr = ParseOpenAIError(resp, schemas.ContainerFileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -6982,7 +6751,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, // Decode response body (handles content-encoding like gzip) responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -7009,9 +6778,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, Object: deleteResp.Object, Deleted: deleteResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.ContainerFileDeleteRequest, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -7080,7 +6847,7 @@ func (provider *OpenAIProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. @@ -7096,9 +6863,6 @@ func (provider *OpenAIProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -7166,9 +6930,9 @@ func (provider *OpenAIProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := make(map[string]string) @@ -7182,9 +6946,7 @@ func (provider *OpenAIProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Wrap reader with idle timeout to detect stalled streams. @@ -7193,11 +6955,7 @@ func (provider *OpenAIProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -7207,9 +6965,9 @@ func (provider *OpenAIProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -7258,7 +7016,7 @@ func (provider *OpenAIProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go index c37040ce62..d2173e1ac7 100644 --- a/core/providers/openai/openai_test.go +++ b/core/providers/openai/openai_test.go @@ -46,69 +46,69 @@ func TestOpenAI(t *testing.T) { ChatAudioModel: "gpt-4o-mini-audio-preview", PassthroughModel: "gpt-4o", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - WebSearchTool: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - FileBase64: true, - FileURL: true, - CompleteEnd2End: true, - SpeechSynthesis: true, - SpeechSynthesisStream: true, - Transcription: true, - TranscriptionStream: true, - Embedding: true, - Reasoning: true, - ListModels: true, - ImageGeneration: true, - ImageGenerationStream: true, - ImageEdit: true, - ImageEditStream: true, - ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations - VideoGeneration: false, // disabled for now because of long running operations - VideoRetrieve: false, - VideoRemix: false, - VideoDownload: false, - VideoList: false, - VideoDelete: false, - BatchCreate: true, - BatchList: true, - BatchRetrieve: true, - BatchCancel: true, - BatchResults: true, - FileUpload: true, - FileList: true, - FileRetrieve: true, - FileDelete: true, - FileContent: true, - FileBatchInput: true, - CountTokens: true, - ChatAudio: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - ContainerCreate: true, - ContainerList: true, - ContainerRetrieve: true, - ContainerDelete: true, - ContainerFileCreate: true, - ContainerFileList: true, - ContainerFileRetrieve: true, - ContainerFileContent: true, - ContainerFileDelete: true, - PromptCaching: true, - PassthroughAPI: true, - WebSocketResponses: true, - Realtime: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + WebSearchTool: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + FileBase64: true, + FileURL: true, + CompleteEnd2End: true, + SpeechSynthesis: true, + SpeechSynthesisStream: true, + Transcription: true, + TranscriptionStream: true, + Embedding: true, + Reasoning: true, + ListModels: true, + ImageGeneration: true, + ImageGenerationStream: true, + ImageEdit: true, + ImageEditStream: true, + ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations + VideoGeneration: false, // disabled for now because of long running operations + VideoRetrieve: false, + VideoRemix: false, + VideoDownload: false, + VideoList: false, + VideoDelete: false, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: true, + FileBatchInput: true, + CountTokens: true, + ChatAudio: true, + StructuredOutputs: true, // Structured outputs with nullable enum support + ContainerCreate: true, + ContainerList: true, + ContainerRetrieve: true, + ContainerDelete: true, + ContainerFileCreate: true, + ContainerFileList: true, + ContainerFileRetrieve: true, + ContainerFileContent: true, + ContainerFileDelete: true, + PromptCaching: true, + PassthroughAPI: true, + WebSocketResponses: true, + Realtime: false, }, RealtimeModel: "gpt-4o-realtime-preview", } diff --git a/core/providers/openai/realtime.go b/core/providers/openai/realtime.go index b73db4ea24..8c88382297 100644 --- a/core/providers/openai/realtime.go +++ b/core/providers/openai/realtime.go @@ -1,13 +1,17 @@ package openai import ( + "bytes" "encoding/json" "fmt" + "mime/multipart" + "net/http" "net/url" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" ) // SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API. @@ -28,7 +32,6 @@ func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model stri func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string { headers := map[string]string{ "Authorization": "Bearer " + key.Value.GetValue(), - "OpenAI-Beta": "realtime=v1", } for k, v := range provider.networkConfig.ExtraHeaders { headers[k] = v @@ -36,6 +39,380 @@ func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]stri return headers } +// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange. +func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool { + return true +} + +// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls. +func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + model string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + path := "/v1/realtime/calls" + if session == nil && strings.TrimSpace(model) != "" { + path += "?model=" + url.QueryEscape(model) + } + return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session) +} + +// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime. +// Same multipart format but targets the legacy endpoint with model in the URL. +func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + sdp string, + session json.RawMessage, + model string, +) (string, *schemas.BifrostError) { + return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session) +} + +// exchangeWebRTCSDP is the shared multipart SDP exchange implementation. +// Builds a multipart body with sdp + optional session, POSTs to the given path. +func (provider *OpenAIProvider) exchangeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + path string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + bodyBuf := &bytes.Buffer{} + writer := multipart.NewWriter(bodyBuf) + if err := writer.WriteField("sdp", sdp); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err) + } + if session != nil { + if err := writer.WriteField("session", string(session)); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err) + } + } + if err := writer.Close(); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err) + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType(writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+key.Value.GetValue()) + for k, v := range provider.networkConfig.ExtraHeaders { + req.Header.Set(k, v) + } + if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil { + if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" { + req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK) + } + } + req.SetBody(bodyBuf.Bytes()) + + _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return "", bifrostErr + } + + answerBody := resp.Body() + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), answerBody) + } + + return string(answerBody), nil +} + +func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(fasthttp.StatusBadGateway), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("upstream_connection_error"), + Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider.GetProviderKey(), + }, + } + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostErr.ExtraFields.RawResponse = map[string]any{ + "status": statusCode, + "body": string(body), + } + } + return bifrostErr +} + +func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + }, + } + if err != nil { + bifrostErr.Error.Error = err + } + return bifrostErr +} + +func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + if event == nil { + return false + } + switch event.Type { + case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted: + return true + default: + return false + } +} + +func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string { + return "oai-events" +} + +func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string { + return "realtime" +} + +func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider *OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + switch eventType { + case schemas.RTEventResponseTextDelta, + schemas.RTEventResponseAudioTransDelta, + schemas.RealtimeEventType("response.output_text.delta"), + schemas.RealtimeEventType("response.output_audio_transcript.delta"): + return true + default: + return false + } +} + +// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns +// the native OpenAI response body unchanged. +func (provider *OpenAIProvider) CreateRealtimeClientSecret( + ctx *schemas.BifrostContext, + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, + rawRequest json.RawMessage, +) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil { + return nil, err + } + + normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType) + if bifrostErr != nil { + return nil, bifrostErr + } + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + for k, v := range provider.realtimeSessionHeaders(key, endpointType) { + req.Header.Set(k, v) + } + req.SetBody(normalizedBody) + + latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return nil, bifrostErr + } + + headers := providerUtils.ExtractProviderResponseHeaders(resp) + ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers) + + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return nil, ParseOpenAIError(resp) + } + + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) + } + for k := range headers { + if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { + delete(headers, k) + } + } + + out := &schemas.BifrostPassthroughResponse{ + StatusCode: resp.StatusCode(), + Headers: headers, + Body: body, + } + out.ExtraFields.Provider = provider.GetProviderKey() + out.ExtraFields.OriginalModelRequested = requestedModel + out.ExtraFields.RequestType = schemas.RealtimeRequest + out.ExtraFields.Latency = latency.Milliseconds() + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields) + } + + return out, nil +} + +func normalizeRealtimeClientSecretRequest( + rawRequest json.RawMessage, + defaultProvider schemas.ModelProvider, + endpointType schemas.RealtimeSessionEndpointType, +) ([]byte, string, *schemas.BifrostError) { + root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest) + if bifrostErr != nil { + return nil, "", bifrostErr + } + + modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root) + if bifrostErr != nil { + return nil, "", bifrostErr + } + providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider) + if normalizedModel == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + if providerKey == "" { + providerKey = defaultProvider + } + if providerKey == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil) + } + + if endpointType == schemas.RealtimeSessionEndpointSessions { + return normalizeRealtimeSessionsRequest(root, normalizedModel) + } + + return normalizeRealtimeClientSecretsRequest(root, normalizedModel) +} + +func normalizeRealtimeClientSecretsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + session := map[string]json.RawMessage{} + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + session["model"] = modelJSON + if _, ok := session["type"]; !ok { + typeJSON, marshalErr := json.Marshal("realtime") + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr) + } + session["type"] = typeJSON + } + delete(root, "model") + + sessionJSON, marshalErr := json.Marshal(session) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr) + } + root["session"] = sessionJSON + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func normalizeRealtimeSessionsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + session := map[string]json.RawMessage{} + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + for key, value := range session { + if _, exists := root[key]; !exists { + root[key] = value + } + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + root["model"] = modelJSON + delete(root, "session") + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func (provider *OpenAIProvider) realtimeSessionHeaders( + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, +) map[string]string { + headers := map[string]string{ + "Authorization": "Bearer " + key.Value.GetValue(), + } + if endpointType == schemas.RealtimeSessionEndpointSessions { + headers["OpenAI-Beta"] = "realtime=v1" + } + for k, v := range provider.networkConfig.ExtraHeaders { + headers[k] = v + } + return headers +} + +func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string { + if endpointType == schemas.RealtimeSessionEndpointSessions { + return "/v1/realtime/sessions" + } + return "/v1/realtime/client_secrets" +} + +func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: schemas.OpenAI, + }, + } +} + // openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event. type openAIRealtimeEvent struct { Type string `json:"type"` @@ -44,15 +421,17 @@ type openAIRealtimeEvent struct { Conversation json.RawMessage `json:"conversation,omitempty"` Item json.RawMessage `json:"item,omitempty"` Response json.RawMessage `json:"response,omitempty"` + Part json.RawMessage `json:"part,omitempty"` Delta string `json:"delta,omitempty"` Audio string `json:"audio,omitempty"` Transcript string `json:"transcript,omitempty"` Text string `json:"text,omitempty"` Error json.RawMessage `json:"error,omitempty"` ItemID string `json:"item_id,omitempty"` - OutputIndex int `json:"output_index,omitempty"` - ContentIndex int `json:"content_index,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` ResponseID string `json:"response_id,omitempty"` + AudioEndMS *int `json:"audio_end_ms,omitempty"` PreviousItemID string `json:"previous_item_id,omitempty"` } @@ -105,6 +484,17 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes EventID: raw.EventID, RawData: providerEvent, } + setRealtimeExtraParam(event, "item_id", raw.ItemID) + setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID) + setRealtimeExtraParam(event, "output_index", raw.OutputIndex) + setRealtimeExtraParam(event, "content_index", raw.ContentIndex) + setRealtimeExtraParam(event, "response_id", raw.ResponseID) + setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS) + setRealtimeExtraParam(event, "transcript", raw.Transcript) + setRealtimeExtraParam(event, "text", raw.Text) + setRealtimeExtraParam(event, "conversation", raw.Conversation) + setRealtimeExtraParam(event, "response", raw.Response) + setRealtimeExtraParam(event, "part", raw.Part) switch { case raw.Session != nil: @@ -123,8 +513,10 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes OutputAudioType: sess.OutputAudioType, Tools: sess.Tools, } + if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 { + event.Session.ExtraParams = extra + } } - case raw.Item != nil: var item openAIRealtimeItem if err := json.Unmarshal(raw.Item, &item); err == nil { @@ -139,6 +531,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Arguments: item.Arguments, Output: item.Output, } + if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 { + event.Item.ExtraParams = extra + } } case raw.Error != nil: @@ -150,6 +545,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Message: rtErr.Message, Param: rtErr.Param, } + if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 { + event.Error.ExtraParams = extra + } } } @@ -159,8 +557,8 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Audio: raw.Audio, Transcript: raw.Transcript, ItemID: raw.ItemID, - OutputIdx: &raw.OutputIndex, - ContentIdx: &raw.ContentIndex, + OutputIdx: raw.OutputIndex, + ContentIdx: raw.ContentIndex, ResponseID: raw.ResponseID, } if raw.Delta != "" { @@ -175,19 +573,19 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes // ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON. func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - out := map[string]interface{}{ "type": string(bifrostEvent.Type), } if bifrostEvent.EventID != "" { out["event_id"] = bifrostEvent.EventID } + mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams) if bifrostEvent.Session != nil { sess := map[string]interface{}{} + if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate { + sess["id"] = bifrostEvent.Session.ID + } if bifrostEvent.Session.Model != "" { sess["model"] = bifrostEvent.Session.Model } @@ -218,6 +616,7 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Session.Tools != nil { sess["tools"] = bifrostEvent.Session.Tools } + mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type) out["session"] = sess } @@ -231,6 +630,9 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Role != "" { item["role"] = bifrostEvent.Item.Role } + if bifrostEvent.Item.Status != "" { + item["status"] = bifrostEvent.Item.Status + } if bifrostEvent.Item.Content != nil { item["content"] = bifrostEvent.Item.Content } @@ -246,9 +648,28 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Output != "" { item["output"] = bifrostEvent.Item.Output } + mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams) out["item"] = item } + if bifrostEvent.Error != nil { + rtErr := map[string]interface{}{} + if bifrostEvent.Error.Type != "" { + rtErr["type"] = bifrostEvent.Error.Type + } + if bifrostEvent.Error.Code != "" { + rtErr["code"] = bifrostEvent.Error.Code + } + if bifrostEvent.Error.Message != "" { + rtErr["message"] = bifrostEvent.Error.Message + } + if bifrostEvent.Error.Param != "" { + rtErr["param"] = bifrostEvent.Error.Param + } + mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams) + out["error"] = rtErr + } + if bifrostEvent.Delta != nil { if bifrostEvent.Delta.Text != "" { out["delta"] = bifrostEvent.Delta.Text @@ -259,16 +680,16 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Delta.Transcript != "" { out["transcript"] = bifrostEvent.Delta.Transcript } - if bifrostEvent.Delta.ItemID != "" { + if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") { out["item_id"] = bifrostEvent.Delta.ItemID } - if bifrostEvent.Delta.OutputIdx != nil { + if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") { out["output_index"] = *bifrostEvent.Delta.OutputIdx } - if bifrostEvent.Delta.ContentIdx != nil { + if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") { out["content_index"] = *bifrostEvent.Delta.ContentIdx } - if bifrostEvent.Delta.ResponseID != "" { + if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") { out["response_id"] = bifrostEvent.Delta.ResponseID } } @@ -276,11 +697,269 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi return providerUtils.MarshalSorted(out) } +func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) { + filtered := params + if eventType == schemas.RTEventSessionUpdate && len(params) > 0 { + filtered = make(map[string]json.RawMessage, len(params)) + for key, value := range params { + switch key { + case "id", "object", "expires_at", "client_secret": + continue + default: + filtered[key] = value + } + } + } + mergeRealtimeExtraParams(out, filtered) +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: parsed.Response.Usage.TotalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil { + return nil + } + + content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output) + toolCalls := extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output) + if content == "" && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if content != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls} + } + + return message +} + +type openAIRealtimeResponseDoneEnvelope struct { + Response struct { + Output []openAIRealtimeResponseDoneOutput `json:"output"` + Usage *openAIRealtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type openAIRealtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []openAIRealtimeResponseDoneBlock `json:"content"` +} + +type openAIRealtimeResponseDoneBlock struct { + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type openAIRealtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type openAIRealtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type openAIRealtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string { + var sb strings.Builder + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + sb.WriteString(block.Text) + case strings.TrimSpace(block.Transcript) != "": + sb.WriteString(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + sb.WriteString(block.Refusal) + } + } + } + return strings.TrimSpace(sb.String()) +} + +func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) { + if event == nil || key == "" || value == nil { + return + } + + switch v := value.(type) { + case string: + if v == "" { + return + } + case *int: + if v == nil { + return + } + case json.RawMessage: + if len(v) == 0 || string(v) == "null" { + return + } + } + + raw, err := json.Marshal(value) + if err != nil { + return + } + if event.ExtraParams == nil { + event.ExtraParams = make(map[string]json.RawMessage) + } + event.ExtraParams[key] = raw +} + +func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) { + for key, raw := range params { + if len(raw) == 0 { + continue + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + continue + } + out[key] = value + } +} + +func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool { + if params == nil { + return false + } + raw, ok := params[key] + return ok && len(raw) > 0 +} + +func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage { + if len(raw) == 0 { + return nil + } + root := map[string]json.RawMessage{} + if err := json.Unmarshal(raw, &root); err != nil { + return nil + } + for _, key := range knownKeys { + delete(root, key) + } + if len(root) == 0 { + return nil + } + return root +} + func isRealtimeDeltaEvent(eventType string) bool { switch eventType { case "response.text.delta", + "response.output_text.delta", "response.audio.delta", + "response.output_audio.delta", "response.audio_transcript.delta", + "response.output_audio_transcript.delta", "conversation.item.input_audio_transcription.delta": return true } diff --git a/core/providers/openai/realtime_test.go b/core/providers/openai/realtime_test.go new file mode 100644 index 0000000000..6b7f76f98f --- /dev/null +++ b/core/providers/openai/realtime_test.go @@ -0,0 +1,561 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestNormalizeRealtimeClientSecretRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["model"]; ok { + t.Fatal("top-level model should be removed after normalization") + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeSessionsRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointSessions, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["session"]; ok { + t.Fatal("legacy sessions endpoint should not forward nested session object") + } + if payload["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview") + } + if payload["voice"] != "alloy" { + t.Fatalf("voice = %v, want %q", payload["voice"], "alloy") + } +} + +func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + contentIndex, err := json.Marshal(0) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + audioEndMS, err := json.Marshal(640) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.truncate"), + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "content_index": contentIndex, + "audio_end_ms": audioEndMS, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "conversation.item.truncate" { + t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate") + } + if payload["item_id"] != "item_123" { + t.Fatalf("item_id = %v, want %q", payload["item_id"], "item_123") + } + if payload["content_index"] != float64(0) { + t.Fatalf("content_index = %v, want 0", payload["content_index"]) + } + if payload["audio_end_ms"] != float64(640) { + t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"]) + } +} + +func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } + var audioEndMS int + if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil { + t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err) + } + if audioEndMS != 640 { + t.Fatalf("audio_end_ms = %d, want 640", audioEndMS) + } +} + +func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var transcript string + if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil { + t.Fatalf("json.Unmarshal(transcript) error = %v", err) + } + if transcript != "Who are you?" { + t.Fatalf("transcript = %q, want %q", transcript, "Who are you?") + } +} + +func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.output_text.delta", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "response_id":"resp_123", + "delta":"hello" + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + if event.Delta == nil || event.Delta.Text != "hello" { + t.Fatalf("Delta = %+v, want text=hello", event.Delta) + } +} + +func TestShouldStartRealtimeTurn(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + tests := []struct { + name string + event *schemas.BifrostRealtimeEvent + want bool + }{ + { + name: "response create starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate}, + want: true, + }, + { + name: "audio buffer committed starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted}, + want: true, + }, + { + name: "response done does not start turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone}, + want: false, + }, + { + name: "nil event does not start turn", + event: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want { + t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + outputIndex := 0 + contentIndex := 0 + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("response.output_text.delta"), + Delta: &schemas.RealtimeDelta{ + Text: "hello", + ItemID: "item_123", + OutputIdx: &outputIndex, + ContentIdx: &contentIndex, + ResponseID: "resp_123", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "response.output_text.delta" { + t.Fatalf("type = %v, want response.output_text.delta", payload["type"]) + } + if payload["delta"] != "hello" { + t.Fatalf("delta = %v, want hello", payload["delta"]) + } +} + +func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionCreated, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + session, ok := payload["session"].(map[string]any) + if !ok { + t.Fatalf("session = %T, want object", payload["session"]) + } + if session["id"] != "sess_123" { + t.Fatalf("session.id = %v, want sess_123", session["id"]) + } +} + +func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`) + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.retrieved"), + Item: &schemas.RealtimeItem{ + ID: "item_123", + Type: "message", + Role: "user", + Status: "completed", + Content: content, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + item, ok := payload["item"].(map[string]any) + if !ok { + t.Fatalf("item = %T, want object", payload["item"]) + } + if item["status"] != "completed" { + t.Fatalf("item.status = %v, want completed", item["status"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.done", + "event_id":"evt_123", + "response":{ + "id":"resp_123", + "output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}] + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var response map[string]any + if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil { + t.Fatalf("json.Unmarshal(response) error = %v", err) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseDone, + ExtraParams: map[string]json.RawMessage{ + "response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + response, ok := payload["response"].(map[string]any) + if !ok { + t.Fatalf("response = %T, want object", payload["response"]) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.content_part.added", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "part":{ + "type":"text", + "text":"hello" + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var part map[string]any + if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil { + t.Fatalf("json.Unmarshal(part) error = %v", err) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseContentPartAdded, + ExtraParams: map[string]json.RawMessage{ + "part": json.RawMessage(`{"type":"text","text":"hello"}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + part, ok := payload["part"].(map[string]any) + if !ok { + t.Fatalf("part = %T, want object", payload["part"]) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + event, err := schemas.ParseRealtimeEvent([]byte(`{ + "type":"session.update", + "session":{ + "type":"realtime", + "model":"gpt-4o-realtime-preview", + "output_modalities":["text"] + } + }`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + if event.Session == nil { + t.Fatal("expected session to be parsed") + } + var outputModalities []string + if err := json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil { + t.Fatalf("json.Unmarshal(output_modalities) error = %v", err) + } + if len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("output_modalities = %v, want [text]", outputModalities) + } +} + +func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + Model: "gpt-4o-realtime-preview", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "output_modalities": json.RawMessage(`["text"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Type string `json:"type"` + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload.Type != "session.update" { + t.Fatalf("type = %q, want %q", payload.Type, "session.update") + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + outputModalities, ok := payload.Session["output_modalities"].([]any) + if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"]) + } +} + +func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "object": json.RawMessage(`"realtime.session"`), + "expires_at": json.RawMessage(`1774614381`), + "client_secret": json.RawMessage(`{"value":"secret"}`), + "modalities": json.RawMessage(`["text","audio"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + for _, key := range []string{"id", "object", "expires_at", "client_secret"} { + if _, ok := payload.Session[key]; ok { + t.Fatalf("session.%s unexpectedly present in session.update payload", key) + } + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + if payload.Session["model"] != "gpt-realtime" { + t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"]) + } +} diff --git a/core/providers/openai/text_test.go b/core/providers/openai/text_test.go index b2dd53ee35..71c2f195a0 100644 --- a/core/providers/openai/text_test.go +++ b/core/providers/openai/text_test.go @@ -51,7 +51,6 @@ func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/openai/transcription.go b/core/providers/openai/transcription.go index 8ab2305b05..8c2bf112a1 100644 --- a/core/providers/openai/transcription.go +++ b/core/providers/openai/transcription.go @@ -54,63 +54,63 @@ func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return utils.NewBifrostOperationError("failed to create form file", err, providerName) + return utils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(openaiReq.File); err != nil { - return utils.NewBifrostOperationError("failed to write file data", err, providerName) + return utils.NewBifrostOperationError("failed to write file data", err) } // Add model field if err := writer.WriteField("model", openaiReq.Model); err != nil { - return utils.NewBifrostOperationError("failed to write model field", err, providerName) + return utils.NewBifrostOperationError("failed to write model field", err) } // Add optional fields if openaiReq.Language != nil { if err := writer.WriteField("language", *openaiReq.Language); err != nil { - return utils.NewBifrostOperationError("failed to write language field", err, providerName) + return utils.NewBifrostOperationError("failed to write language field", err) } } if openaiReq.Prompt != nil { if err := writer.WriteField("prompt", *openaiReq.Prompt); err != nil { - return utils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return utils.NewBifrostOperationError("failed to write prompt field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return utils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return utils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Temperature != nil { if err := writer.WriteField("temperature", fmt.Sprintf("%g", *openaiReq.Temperature)); err != nil { - return utils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return utils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range openaiReq.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } for _, include := range openaiReq.Include { if err := writer.WriteField("include[]", include); err != nil { - return utils.NewBifrostOperationError("failed to write include field", err, providerName) + return utils.NewBifrostOperationError("failed to write include field", err) } } if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return utils.NewBifrostOperationError("failed to write stream field", err, providerName) + return utils.NewBifrostOperationError("failed to write stream field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return utils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return utils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/videos.go b/core/providers/openai/videos.go index aa4052e029..512306b7c7 100644 --- a/core/providers/openai/videos.go +++ b/core/providers/openai/videos.go @@ -132,30 +132,30 @@ func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *sc func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add prompt field (required) if openaiReq.Prompt == "" { - return providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return providerUtils.NewBifrostOperationError("prompt is required", nil) } if err := writer.WriteField("prompt", openaiReq.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add optional model field if openaiReq.Model != "" { if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } } // Add optional seconds field if openaiReq.Seconds != nil { if err := writer.WriteField("seconds", *openaiReq.Seconds); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seconds field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seconds field", err) } } // Add optional size field if openaiReq.Size != "" { if err := writer.WriteField("size", openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } @@ -196,16 +196,16 @@ func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, opena "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err) } if _, err := part.Write(openaiReq.InputReference); err != nil { - return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 55ec63daad..a1e746834c 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -4,7 +4,6 @@ package openrouter import ( "fmt" "net/http" - "slices" "strings" "time" @@ -95,12 +94,12 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // Check for auth errors (401, 403) statusCode := resp.StatusCode() if statusCode == fasthttp.StatusUnauthorized || statusCode == fasthttp.StatusForbidden { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } // Any 4xx/5xx error indicates the key might be invalid if statusCode >= 400 { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } return nil @@ -109,8 +108,6 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Validate the key first using /v1/auth/key (only during provider add/update). // OpenRouter's /v1/models doesn't require auth, so we need this extra check. shouldValidate := false @@ -158,7 +155,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Continue with empty response; allowed models will be backfilled below. modelsFetched = false } else { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } } @@ -185,45 +182,62 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, } } - // Filter by key.Models - allowedModels := key.Models - blacklistedModels := key.BlacklistedModels + // OpenRouter model IDs in the API response do NOT include the "openrouter/" prefix + // (e.g. the API returns "openai/gpt-4", not "openrouter/openai/gpt-4"). + // Users may supply allowedModels / aliases with or without the prefix, so we + // normalize both by stripping it before feeding into the shared pipeline. providerPrefix := string(schemas.OpenRouter) + "/" + stripPrefix := func(s string) string { + if strings.HasPrefix(strings.ToLower(s), strings.ToLower(providerPrefix)) { + return s[len(providerPrefix):] + } + return s + } + + normalizedAllowed := make(schemas.WhiteList, 0, len(key.Models)) + for _, m := range key.Models { + normalizedAllowed = append(normalizedAllowed, stripPrefix(m)) + } + normalizedBlacklist := make(schemas.BlackList, 0, len(key.BlacklistedModels)) + for _, m := range key.BlacklistedModels { + normalizedBlacklist = append(normalizedBlacklist, stripPrefix(m)) + } + normalizedAliases := make(map[string]string, len(key.Aliases)) + for k, v := range key.Aliases { + normalizedAliases[stripPrefix(k)] = stripPrefix(v) + } + + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: normalizedAllowed, + BlacklistedModels: normalizedBlacklist, + Aliases: normalizedAliases, + Unfiltered: request.Unfiltered, + ProviderKey: schemas.OpenRouter, + MatchFns: providerUtils.DefaultMatchFns(), + } - if !request.Unfiltered && len(allowedModels) > 0 { + if pipeline.ShouldEarlyExit() { + openrouterResponse.Data = make([]schemas.Model, 0) + } else { + included := make(map[string]bool) filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) - includedModels := make(map[string]bool) for i := range openrouterResponse.Data { + // rawID has no "openrouter/" prefix β€” e.g. "openai/gpt-4" rawID := openrouterResponse.Data[i].ID - if !(slices.Contains(allowedModels, rawID) || slices.Contains(allowedModels, providerPrefix+rawID)) { - continue - } - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - openrouterResponse.Data[i].ID = providerPrefix + rawID - filteredData = append(filteredData, openrouterResponse.Data[i]) - includedModels[rawID] = true - } - // Backfill allowed models not in the API response - for _, allowedModel := range allowedModels { - rawID := strings.TrimPrefix(allowedModel, providerPrefix) - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - if !includedModels[rawID] { - filteredData = append(filteredData, schemas.Model{ - ID: providerPrefix + rawID, - Name: schemas.Ptr(rawID), - }) - includedModels[rawID] = true // avoid duplicate backfill + for _, result := range pipeline.FilterModel(rawID) { + entry := openrouterResponse.Data[i] + entry.ID = providerPrefix + result.ResolvedID + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } else { + entry.Alias = nil + } + filteredData = append(filteredData, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + filteredData = append(filteredData, pipeline.BackfillModels(included)...) openrouterResponse.Data = filteredData - } else { - for i := range openrouterResponse.Data { - openrouterResponse.Data[i].ID = providerPrefix + openrouterResponse.Data[i].ID - } } openrouterResponse.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/openrouter/openrouter_test.go b/core/providers/openrouter/openrouter_test.go index 2847aeaf7c..1b3510d0d2 100644 --- a/core/providers/openrouter/openrouter_test.go +++ b/core/providers/openrouter/openrouter_test.go @@ -31,25 +31,25 @@ func TestOpenRouter(t *testing.T) { EmbeddingModel: "", ReasoningModel: "openai/gpt-oss-120b", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: false, // OpenRouter's responses API is in Beta + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: false, // OpenRouter's responses API is in Beta MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, // OpenRouter's responses API is in Beta - ImageBase64: false, // OpenRouter's responses API is in Beta - MultipleImages: false, // OpenRouter's responses API is in Beta - FileBase64: true, - FileURL: true, - CompleteEnd2End: false, // OpenRouter's responses API is in Beta - Reasoning: true, - ListModels: true, - StructuredOutputs: true, // Structured outputs with nullable enum support + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, // OpenRouter's responses API is in Beta + ImageBase64: false, // OpenRouter's responses API is in Beta + MultipleImages: false, // OpenRouter's responses API is in Beta + FileBase64: true, + FileURL: true, + CompleteEnd2End: false, // OpenRouter's responses API is in Beta + Reasoning: true, + ListModels: true, + StructuredOutputs: true, // Structured outputs with nullable enum support }, } diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index 6dc6b74cca..280bf0b64d 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -145,9 +145,6 @@ func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/chat.go b/core/providers/perplexity/chat.go index dafe1c615b..a832ac0ad7 100644 --- a/core/providers/perplexity/chat.go +++ b/core/providers/perplexity/chat.go @@ -280,8 +280,6 @@ func (response *PerplexityChatResponse) ToBifrostChatResponse(model string) *sch Object: response.Object, Created: response.Created, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Perplexity, }, SearchResults: response.SearchResults, Videos: response.Videos, diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index d0e68a6850..2b0abbefb4 100644 --- a/core/providers/perplexity/perplexity.go +++ b/core/providers/perplexity/perplexity.go @@ -101,12 +101,12 @@ func (provider *PerplexityProvider) completeRequest(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) - return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, schemas.ChatCompletionRequest, provider.GetProviderKey(), model) + return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -141,8 +141,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToPerplexityChatCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -161,9 +160,6 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -223,9 +219,6 @@ func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key s } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/replicate/errors.go b/core/providers/replicate/errors.go index e7fc2051d0..1575d9ca77 100644 --- a/core/providers/replicate/errors.go +++ b/core/providers/replicate/errors.go @@ -15,9 +15,6 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: replicateErr.Detail, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -28,8 +25,5 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: string(body), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/replicate/files.go b/core/providers/replicate/files.go index cdd37a65c5..15ca0e13e8 100644 --- a/core/providers/replicate/files.go +++ b/core/providers/replicate/files.go @@ -30,8 +30,6 @@ func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } @@ -67,8 +65,6 @@ func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } diff --git a/core/providers/replicate/images.go b/core/providers/replicate/images.go index 4fa7d0dd81..0327b8b5fa 100644 --- a/core/providers/replicate/images.go +++ b/core/providers/replicate/images.go @@ -1,7 +1,6 @@ package replicate import ( - "strconv" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -28,49 +27,6 @@ var modelInputImageFieldMap = map[string]string{ "black-forest-labs/flux-krea-dev": "image", } -// convertSizeToReplicateFormat converts standard size format (e.g., "1024x1024") to Replicate format. -// Returns (aspectRatio, imageSize) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9". Returns empty strings if unparseable or ratio unrecognised. -func convertSizeToReplicateFormat(size string) (aspectRatio, imageSize string) { - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - if width <= 0 || height <= 0 { - return "", "" - } - - if width <= 1024 && height <= 1024 { - imageSize = "1K" - } else if width <= 2048 && height <= 2048 { - imageSize = "2K" - } else if width <= 4096 && height <= 4096 { - imageSize = "4K" - } - - ratio := float64(width) / float64(height) - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return aspectRatio, imageSize -} - // ToReplicateImageGenerationInput converts a Bifrost image generation request to Replicate prediction input func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest { if bifrostReq == nil || bifrostReq.Input == nil { @@ -85,29 +41,6 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR if bifrostReq.Params != nil { params := bifrostReq.Params - // Map InputImages to the appropriate field based on model - if len(params.InputImages) > 0 { - fieldName := getInputImageFieldName(bifrostReq.Model) - - switch fieldName { - case "image_prompt": - // For flux-1.1-pro variants: use first image as image_prompt - input.ImagePrompt = ¶ms.InputImages[0] - - case "input_image": - // For flux-kontext variants: add to ExtraParams as input_image - input.InputImage = ¶ms.InputImages[0] - - case "image": - // For flux-dev variants: use first image as image field - input.Image = ¶ms.InputImages[0] - - case "input_images": - // For all other models: use input_images array - input.InputImages = params.InputImages - } - } - if bifrostReq.Params.N != nil { input.NumberOfImages = bifrostReq.Params.N } @@ -117,7 +50,7 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitResolution := params.ExtraParams["resolution"] if params.AspectRatio == nil && aspectRatio != "" { input.AspectRatio = &aspectRatio @@ -191,9 +124,6 @@ func ToBifrostImageGenerationResponse( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -294,7 +224,7 @@ func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *Rep input.Image = &images[0] case "input_images": - // For all other models: use input_images array + // For all other models: use input_images array (preserves multi-image support) input.InputImages = images } } @@ -309,7 +239,7 @@ func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *Rep } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitAspectRatio := params.ExtraParams["aspect_ratio"] _, hasExplicitResolution := params.ExtraParams["resolution"] if aspectRatio != "" && !hasExplicitAspectRatio { diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 3989628db1..6c0c14dbf7 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -1,61 +1,66 @@ package replicate import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToBifrostListModelsResponse converts Replicate models and deployments to a Bifrost list models response +// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list models response. +// Replicate model IDs are composite: "{owner}/{name}" (e.g. "stability-ai/stable-diffusion"). func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, - allowedModels []string, - blacklistedModels []string, + allowedModels schemas.WhiteList, + blacklistedModels schemas.BlackList, + aliases map[string]string, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } - includedModels := make(map[string]bool) - // Add deployments from /v1/deployments endpoint + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) + if deploymentsResponse != nil { for _, deployment := range deploymentsResponse.Results { + // Replicate model IDs are composite owner/name deploymentID := deployment.Owner + "/" + deployment.Name - modelName := schemas.Ptr(deployment.Name) var created *int64 - - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, deploymentID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, deploymentID) { - continue - } - - // Extract information from current release if available - if deployment.CurrentRelease != nil { - // Parse created timestamp - if deployment.CurrentRelease.CreatedAt != "" { - createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) - if createdTimestamp > 0 { - created = schemas.Ptr(createdTimestamp) - } + if deployment.CurrentRelease != nil && deployment.CurrentRelease.CreatedAt != "" { + createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) + if createdTimestamp > 0 { + created = schemas.Ptr(createdTimestamp) } } - bifrostModel := schemas.Model{ - ID: string(providerKey) + "/" + deploymentID, - Name: modelName, - Deployment: modelName, - OwnedBy: schemas.Ptr(deployment.Owner), - Created: created, + for _, result := range pipeline.FilterModel(deploymentID) { + bifrostModel := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(deployment.Name), + OwnedBy: schemas.Ptr(deployment.Owner), + Created: created, + } + if result.AliasValue != "" { + bifrostModel.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) - includedModels[deploymentID] = true } if deploymentsResponse.Next != nil { @@ -63,58 +68,8 @@ func ToBifrostListModelsResponse( } } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue - } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) return bifrostResponse } - -// ToReplicateListModelsResponse converts a Bifrost list models response to a Replicate list models response -// This is mainly used for testing and compatibility -func ToReplicateListModelsResponse(response *schemas.BifrostListModelsResponse) *ReplicateModelListResponse { - if response == nil { - return nil - } - - replicateResponse := &ReplicateModelListResponse{ - Results: make([]ReplicateModelResponse, 0, len(response.Data)), - } - - for _, model := range response.Data { - modelID := strings.TrimPrefix(model.ID, string(schemas.Replicate)+"/") - replicateModel := ReplicateModelResponse{ - URL: "https://replicate.com/" + modelID, - Name: modelID, - } - - if model.Description != nil { - replicateModel.Description = model.Description - } - - if model.OwnedBy != nil { - replicateModel.Owner = *model.OwnedBy - } - - replicateResponse.Results = append(replicateResponse.Results, replicateModel) - } - - // Set next page token if available - if response.NextPageToken != "" { - next := response.NextPageToken - replicateResponse.Next = &next - } - - return replicateResponse -} diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 012eaa845e..2fa9c81ab5 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -149,7 +149,7 @@ func createPrediction( // Parse response body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, schemas.Replicate) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var prediction ReplicatePredictionResponse @@ -204,7 +204,7 @@ func getPrediction( // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, schemas.Replicate) + return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } prediction := &ReplicatePredictionResponse{} @@ -252,9 +252,7 @@ func pollPrediction( case <-pollCtx.Done(): return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError( schemas.ErrProviderRequestTimedOut, - fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds), - schemas.Replicate, - ) + fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds)) case <-ticker.C: prediction, rawResponse, providerResponseHeaders, err = getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse) if err != nil { @@ -277,6 +275,17 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont client := provider.client extraHeaders := provider.networkConfig.ExtraHeaders + if key.ReplicateKeyConfig == nil || !key.ReplicateKeyConfig.UseDeploymentsEndpoint { + return ToBifrostListModelsResponse( + &ReplicateDeploymentListResponse{}, + providerName, + key.Models, + key.BlacklistedModels, + key.Aliases, + request.Unfiltered, + ), nil + } + // Build deployments URL deploymentsURL := provider.buildRequestURL(ctx, "/v1/deployments", schemas.ListModelsRequest) @@ -335,9 +344,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont if err := sonic.Unmarshal(bodyCopy, &pageResponse); err != nil { return nil, providerUtils.NewBifrostOperationError( "failed to parse deployments response", - err, - schemas.Replicate, - ) + err) } // Append results from this page @@ -362,6 +369,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont providerName, key.Models, key.BlacklistedModels, + key.Aliases, request.Unfiltered, ) @@ -375,11 +383,10 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys } if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } startTime := time.Now() - providerName := provider.GetProviderKey() response, err := providerUtils.HandleMultipleListModelsRequests( ctx, @@ -393,8 +400,6 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys // Update metadata with total latency latency := time.Since(startTime) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil @@ -406,17 +411,11 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -431,7 +430,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.TextCompletionRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -480,10 +479,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostTextCompletionResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -503,11 +499,6 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -519,8 +510,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -532,7 +522,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.TextCompletionStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -556,9 +546,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -589,9 +577,9 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -636,7 +624,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -667,11 +655,8 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -705,14 +690,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -727,14 +705,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -750,10 +721,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont nil, // usage - not available in done event finishReason, chunkIndex, - schemas.TextCompletionStreamRequest, - provider.GetProviderKey(), - request.Model, - ) + schemas.TextCompletionStreamRequest) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -781,17 +749,11 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -806,7 +768,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.ChatCompletionRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -855,10 +817,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostChatResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -878,11 +837,6 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -894,8 +848,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -907,7 +860,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.ChatCompletionStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -931,9 +884,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -964,9 +915,9 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1011,7 +962,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -1049,11 +1000,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1087,14 +1035,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1109,14 +1050,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -1142,11 +1076,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -1174,17 +1105,11 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -1199,7 +1124,7 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ResponsesRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // create prediction @@ -1246,9 +1171,6 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc // Convert to Bifrost response response := prediction.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1266,24 +1188,18 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } // Enable streaming (using sjson to set field directly, preserving key order) if updatedData, err := providerUtils.SetJSONField(jsonData, "stream", true); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err) } else { jsonData = updatedData } @@ -1295,7 +1211,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ResponsesStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -1319,9 +1235,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1360,9 +1274,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(streamErr, fasthttp.ErrTimeout) || errors.Is(streamErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1391,9 +1305,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1405,10 +1319,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, if reader == nil { bifrostErr := providerUtils.NewBifrostOperationError( - "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - provider.GetProviderKey(), - ) + "provider returned an empty response", + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger) return @@ -1455,7 +1367,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()) + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1497,11 +1409,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } if sendBackRawRequest { @@ -1524,10 +1433,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CreatedAt: int(startTime.Unix()), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1556,10 +1462,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1587,10 +1490,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1610,10 +1510,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, Delta: schemas.Ptr(currentEvent.Data), LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1639,10 +1536,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, ItemID: schemas.Ptr(itemID), LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1665,10 +1559,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1702,10 +1593,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: sequenceNumber, + ChunkIndex: sequenceNumber, }, } providerUtils.ProcessAndSendResponse(ctx, postHookRunner, @@ -1725,11 +1613,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, CompletedAt: schemas.Ptr(int(time.Now().Unix())), }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - Latency: time.Since(startTime).Milliseconds(), - ChunkIndex: sequenceNumber, + Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: sequenceNumber, }, } @@ -1762,14 +1647,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext, } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream error: %s", errorMsg), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } + fmt.Errorf("stream error: %s", errorMsg)) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { @@ -1825,19 +1703,13 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageGenerationInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1852,7 +1724,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageGenerationRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -1904,10 +1776,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext, } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -1926,15 +1795,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -1943,8 +1806,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon replicateReq := ToReplicateImageGenerationInput(request) replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1956,7 +1818,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon request.Model, provider.customProviderConfig, schemas.ImageGenerationStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction prediction, _, _, _, err := createPrediction( @@ -1977,10 +1839,16 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2011,9 +1879,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2060,7 +1928,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, provider.logger) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2105,11 +1974,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2143,36 +2009,24 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("stream ended with error"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + fmt.Errorf("stream ended with error")) // Include accumulated raw responses in error if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2187,11 +2041,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon OutputFormat: lastOutputFormat, // Include output format CreatedAt: time.Now().Unix(), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2233,17 +2084,13 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon Error: &schemas.ErrorField{ Message: errorMsg, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } // Include accumulated raw responses in error if sendBackRawResponse { rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData}) bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2260,19 +2107,13 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateImageEditInput(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2287,7 +2128,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc request.Model, provider.customProviderConfig, schemas.ImageEditRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -2339,10 +2180,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc } // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -2361,15 +2199,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, return nil, err } - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -2378,8 +2210,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, replicateReq := ToReplicateImageEditInput(request) replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2391,7 +2222,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.ImageEditStreamRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction @@ -2413,10 +2244,16 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, // Verify stream URL is available if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { - return nil, providerUtils.NewBifrostOperationError( - "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - providerName, + return nil, providerUtils.EnrichError( + ctx, + providerUtils.NewBifrostOperationError( + "stream URL not available in prediction response", + fmt.Errorf("prediction response missing stream URL"), + ), + jsonData, + nil, + sendBackRawRequest, + sendBackRawResponse, ) } @@ -2447,9 +2284,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -2494,18 +2331,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, if errors.Is(readErr, context.Canceled) { return } - bifrostErr := providerUtils.NewBifrostOperationError( - "stream read error", - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) + providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break } @@ -2548,11 +2376,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, CreatedAt: time.Now().Unix(), OutputFormat: outputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -2586,34 +2411,22 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return case "error": bifrostErr := providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("stream ended with error"), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("stream ended with error")) if sendBackRawResponse && len(rawResponseChunks) > 0 { bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2628,11 +2441,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, CreatedAt: time.Now().Unix(), OutputFormat: lastOutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } @@ -2660,18 +2470,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext, bifrostErr := providerUtils.NewBifrostOperationError( "stream error", - fmt.Errorf("%s", errorData.Detail), - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + fmt.Errorf("%s", errorData.Detail)) if sendBackRawResponse { rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData}) bifrostErr.ExtraFields.RawResponse = rawResponseChunks } + bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -2693,21 +2497,13 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - - providerName := provider.GetProviderKey() - // Convert Bifrost request to Replicate format jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateVideoGenerationInput(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -2719,7 +2515,7 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, request.Model, provider.customProviderConfig, schemas.VideoGenerationRequest, - isDeployment, + key.ReplicateKeyConfig.UseDeploymentsEndpoint, ) // Create prediction with appropriate mode @@ -2748,13 +2544,10 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext, if err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName) + bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, schemas.Replicate) // Set extra fields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -2774,7 +2567,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -2816,7 +2609,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) @@ -2828,12 +2621,10 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke bifrostResponse, convertErr := ToBifrostVideoGenerationResponse(&prediction) if convertErr != nil { - return nil, providerUtils.EnrichError(ctx, convertErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, convertErr, nil, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName) - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if sendBackRawResponse { @@ -2848,9 +2639,8 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve latest status/output first. bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2864,19 +2654,17 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var videoUrl string if videoResp.Videos[0].URL != nil { videoUrl = *videoResp.Videos[0].URL } if videoUrl == "" { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -2896,9 +2684,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2906,7 +2692,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType := string(resp.Header.ContentType()) if contentType == "" { @@ -2920,8 +2706,6 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2977,7 +2761,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -3014,22 +2798,22 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s part, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } // Add filename field if provided if filename != "" { if err := writer.WriteField("filename", filename); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err) } } // Add type field (content type) if err := writer.WriteField("type", contentType); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write type field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write type field", err) } // Add metadata field if provided @@ -3038,24 +2822,24 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s if len(metadata) > 0 { metadataJSON, err := providerUtils.MarshalSorted(metadata) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } h := make(textproto.MIMEHeader) h.Set("Content-Disposition", `form-data; name="metadata"`) h.Set("Content-Type", "application/json") metadataPart, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err) } if _, err := metadataPart.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err) } } } } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3091,7 +2875,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var replicateResp ReplicateFileResponse @@ -3119,7 +2903,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper (Replicate uses cursor-based pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3130,10 +2914,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3182,7 +2962,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var replicateResp ReplicateFileListResponse @@ -3226,8 +3006,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: finalHasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3244,7 +3022,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3289,7 +3067,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3321,7 +3099,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3364,8 +3142,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3386,7 +3162,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3411,8 +3187,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, diff --git a/core/providers/replicate/replicate_test.go b/core/providers/replicate/replicate_test.go index c6f72cfded..a9179f963c 100644 --- a/core/providers/replicate/replicate_test.go +++ b/core/providers/replicate/replicate_test.go @@ -459,187 +459,6 @@ func TestBifrostToReplicateImageGenerationConversion(t *testing.T) { validate func(t *testing.T, result *replicate.ReplicatePredictionRequest) wantErr bool }{ - { - name: "Flux_1_1_Pro_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Flux 1.1 Pro should use ImagePrompt field - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.Image) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_1_1_Pro_Ultra_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro-ultra", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, - { - name: "Flux_Pro_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, - { - name: "Flux_Kontext_Pro_InputImage", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-kontext-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Kontext models should use InputImage field - assert.NotNil(t, result.Input.InputImage) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.Image) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_Kontext_Max_InputImage", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-kontext-max", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.InputImage) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage) - }, - }, - { - name: "Flux_Dev_Image", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-dev", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Flux Dev should use Image field - assert.NotNil(t, result.Input.Image) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_Fill_Pro_Image", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-fill-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.Image) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image) - }, - }, - { - name: "Other_Model_InputImages", - input: &schemas.BifrostImageGenerationRequest{ - Model: "stability-ai/sdxl", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input1.jpg", "https://example.com/input2.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Other models should use InputImages array - assert.NotNil(t, result.Input.InputImages) - assert.Len(t, result.Input.InputImages, 2) - assert.Equal(t, "https://example.com/input1.jpg", result.Input.InputImages[0]) - assert.Equal(t, "https://example.com/input2.jpg", result.Input.InputImages[1]) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.Image) - }, - }, - { - name: "Model_With_Version", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro:v1.0", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Should still match flux-1.1-pro and use ImagePrompt - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, { name: "AllParameters", input: &schemas.BifrostImageGenerationRequest{ diff --git a/core/providers/replicate/utils.go b/core/providers/replicate/utils.go index 3279b0a847..1d88337539 100644 --- a/core/providers/replicate/utils.go +++ b/core/providers/replicate/utils.go @@ -31,17 +31,13 @@ func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.Bifro } return providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("%s", errorMsg), - schemas.Replicate, - ) + fmt.Errorf("%s", errorMsg)) } if prediction.Status == ReplicatePredictionStatusCanceled { return providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("prediction was canceled"), - schemas.Replicate, - ) + fmt.Errorf("prediction was canceled")) } return nil @@ -126,9 +122,9 @@ func listenToReplicateStreamURL( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -178,24 +174,12 @@ func isVersionID(s string) bool { return versionIDPattern.MatchString(s) } -// resolveDeploymentModel checks if the model maps to a deployment. -// Returns the resolved model and whether it is a deployment. -func resolveDeploymentModel(model string, key schemas.Key) (string, bool) { - if key.ReplicateKeyConfig == nil || key.ReplicateKeyConfig.Deployments == nil { - return model, false - } - if deployment, ok := key.ReplicateKeyConfig.Deployments[model]; ok && deployment != "" { - return deployment, true - } - return model, false -} - // buildPredictionURL builds the appropriate URL for creating a prediction // Returns the URL for the appropriate prediction endpoint. -func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, isDeployment bool) string { +func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string { var defaultPath string - if isDeployment { + if useDeploymentsEndpoint { defaultPath = "/v1/deployments/" + model + "/predictions" } else if isVersionID(model) { // If model is a version ID, use base predictions endpoint diff --git a/core/providers/replicate/videos.go b/core/providers/replicate/videos.go index b6dadaab55..3a277d067d 100644 --- a/core/providers/replicate/videos.go +++ b/core/providers/replicate/videos.go @@ -87,9 +87,6 @@ func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) ( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/runway/errors.go b/core/providers/runway/errors.go index a64f8ffc60..d9259e825f 100644 --- a/core/providers/runway/errors.go +++ b/core/providers/runway/errors.go @@ -9,7 +9,7 @@ import ( ) // parseRunwayError parses Runway API error responses and converts them to BifrostError. -func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseRunwayError(resp *fasthttp.Response) *schemas.BifrostError { // Parse as RunwayAPIError var errorResp RunwayAPIError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -34,12 +34,5 @@ func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Message = strings.TrimRight(bifrostErr.Error.Message, "\n") } - // Set metadata - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } - return bifrostErr } diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go index d512742afd..a4d95c2bf8 100644 --- a/core/providers/runway/runway.go +++ b/core/providers/runway/runway.go @@ -165,8 +165,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostReq, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToRunwayVideoGenerationRequest(bifrostReq) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -205,17 +204,14 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -232,10 +228,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key Object: "video", Status: schemas.VideoStatusQueued, ExtraFields: schemas.BifrostResponseExtraFields{ - Latency: latency.Milliseconds(), - Provider: providerName, - ModelRequested: model, - RequestType: schemas.VideoGenerationRequest, + Latency: latency.Milliseconds(), }, } @@ -282,16 +275,14 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -309,8 +300,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest if sendBackRawRequest { bifrostResp.ExtraFields.RawRequest = rawRequest @@ -324,7 +313,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // VideoDownload retrieves a video from Runway's API. func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() // Retrieve task status to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ Provider: request.Provider, @@ -338,20 +326,21 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if taskDetails.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", taskDetails.Status), - nil, - providerName, - ) + nil) } if len(taskDetails.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var videoUrl string if taskDetails.Videos[0].URL != nil { videoUrl = *taskDetails.Videos[0].URL } if videoUrl == "" { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + // Download video from Runway's URL req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -367,14 +356,13 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } // Get content and content type body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } contentType := string(resp.Header.ContentType()) if contentType == "" { @@ -389,8 +377,6 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest return bifrostResp, nil } @@ -402,7 +388,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("task_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("task_id is required", nil) } taskID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -434,10 +420,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch // Handle error response - Runway returns 204 No Content on success if resp.StatusCode() != fasthttp.StatusNoContent { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoDeleteRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Build response - Runway returns empty body on 204 @@ -448,8 +431,6 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch } response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.VideoDeleteRequest return response, nil } diff --git a/core/providers/runway/videos.go b/core/providers/runway/videos.go index 49b0cec237..809a8a1038 100644 --- a/core/providers/runway/videos.go +++ b/core/providers/runway/videos.go @@ -121,7 +121,7 @@ func ToRunwayVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Runway task details to Bifrost video generation response format. func ToBifrostVideoGenerationResponse(taskDetails *RunwayTaskDetailsResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if taskDetails == nil { - return nil, providerUtils.NewBifrostOperationError("task details is nil", nil, schemas.Runway) + return nil, providerUtils.NewBifrostOperationError("task details is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index 68a212f6b7..25a1375d06 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -3,7 +3,6 @@ package sgl import ( - "fmt" "strings" "time" @@ -50,11 +49,7 @@ func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGL client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for SGLang - if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for sgl provider") - } - + // BaseURL is optional when keys have sgl_key_config with per-key URLs return &SGLProvider{ logger: logger, client: client, @@ -69,27 +64,40 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { return schemas.SGL } -// ListModels performs a list models request to SGL's API. -func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return openai.HandleOpenAIListModelsRequest( +// listModelsByKey performs a list models request for a single SGL key, +// resolving the per-key URL so each backend is queried individually. +func (provider *SGLProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return openai.ListModelsByKey( ctx, provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), + key, + request.Unfiltered, provider.networkConfig.ExtraHeaders, - schemas.SGL, + provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), ) } -// TextCompletion is not supported by the SGL provider. +// ListModels performs a list models request to SGL's API. +// Requests are made concurrently per key so that each backend is queried +// with its own URL (from sgl_key_config). +func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return providerUtils.HandleMultipleListModelsRequests( + ctx, + keys, + request, + provider.listModelsByKey, + ) +} + +// TextCompletion performs a text completion request to the SGL API. func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -109,7 +117,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -129,7 +137,7 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -151,7 +159,7 @@ func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -176,9 +184,6 @@ func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas. } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -194,12 +199,12 @@ func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo ) } -// Embedding is not supported by the SGL provider. +// Embedding performs an embedding request to the SGL API. func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/sgl/sgl_test.go b/core/providers/sgl/sgl_test.go index 11447f58b4..20236182fc 100644 --- a/core/providers/sgl/sgl_test.go +++ b/core/providers/sgl/sgl_test.go @@ -29,22 +29,22 @@ func TestSGL(t *testing.T) { TextModel: "qwen/qwen2.5-0.5b-instruct", EmbeddingModel: "Alibaba-NLP/gte-Qwen2-1.5B-instruct", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + Embedding: true, + ListModels: true, }, } diff --git a/core/providers/utils/images.go b/core/providers/utils/images.go new file mode 100644 index 0000000000..12bc01c8ba --- /dev/null +++ b/core/providers/utils/images.go @@ -0,0 +1,50 @@ +package utils + +import ( + "strconv" + "strings" +) + +// ConvertSizeToAspectRatioAndResolution converts a standard size string (e.g., "1024x1024") +// to an aspect ratio and image size tier. +// aspectRatio is one of "1:1", "3:4", "4:3", "9:16", "16:9" (empty if unrecognised). +// imageSize is one of "1K", "2K", "4K" (empty if out of range). +func ConvertSizeToAspectRatioAndResolution(size string) (aspectRatio, imageSize string) { + parts := strings.Split(size, "x") + if len(parts) != 2 { + return "", "" + } + + width, err1 := strconv.Atoi(parts[0]) + height, err2 := strconv.Atoi(parts[1]) + if err1 != nil || err2 != nil { + return "", "" + } + + if width <= 0 || height <= 0 { + return "", "" + } + + if width <= 1024 && height <= 1024 { + imageSize = "1K" + } else if width <= 2048 && height <= 2048 { + imageSize = "2K" + } else if width <= 4096 && height <= 4096 { + imageSize = "4K" + } + + ratio := float64(width) / float64(height) + if ratio >= 0.99 && ratio <= 1.01 { + aspectRatio = "1:1" + } else if ratio >= 0.74 && ratio <= 0.76 { + aspectRatio = "3:4" + } else if ratio >= 1.32 && ratio <= 1.34 { + aspectRatio = "4:3" + } else if ratio >= 0.56 && ratio <= 0.57 { + aspectRatio = "9:16" + } else if ratio >= 1.77 && ratio <= 1.78 { + aspectRatio = "16:9" + } + + return aspectRatio, imageSize +} diff --git a/core/providers/utils/large_response.go b/core/providers/utils/large_response.go index a7e0e7bf36..e62d375c9a 100644 --- a/core/providers/utils/large_response.go +++ b/core/providers/utils/large_response.go @@ -116,7 +116,6 @@ func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Resp func FinalizeResponseWithLargeDetection( ctx *schemas.BifrostContext, resp *fasthttp.Response, - providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, bool, *schemas.BifrostError) { responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) @@ -125,7 +124,7 @@ func FinalizeResponseWithLargeDetection( if responseThreshold <= 0 { body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Copy body before caller releases resp return append([]byte(nil), body...), false, nil @@ -142,14 +141,14 @@ func FinalizeResponseWithLargeDetection( } bodyBytes, readErr := io.ReadAll(reader) if readErr != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } return bodyBytes, false, nil } // No stream β€” buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -169,7 +168,7 @@ func FinalizeResponseWithLargeDetection( bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1)) if readErr != nil { releaseGzip() - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } if int64(len(bodyBytes)) <= responseThreshold { releaseGzip() @@ -195,7 +194,7 @@ func FinalizeResponseWithLargeDetection( // No stream β€” buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -206,11 +205,11 @@ func FinalizeResponseWithLargeDetection( if bodyStream == nil { // No stream available β€” fall back to buffered read if logger != nil { - logger.Warn("large-response fallback to buffered path: provider=%s content_length=%d threshold=%d body_stream_nil=true", providerName, contentLength, responseThreshold) + logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d body_stream_nil=true", contentLength, responseThreshold) } body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -232,7 +231,7 @@ func FinalizeResponseWithLargeDetection( if wasGzip { ReleaseGzipReader(gz) } - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } prefetchBuf = prefetchBuf[:n] diff --git a/core/providers/utils/make_request_test.go b/core/providers/utils/make_request_test.go index ce1610d7bb..ec2bf771bc 100644 --- a/core/providers/utils/make_request_test.go +++ b/core/providers/utils/make_request_test.go @@ -295,7 +295,7 @@ func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) } func TestNewBifrostTimeoutError(t *testing.T) { - err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded, "openai") + err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded) if !err.IsBifrostError { t.Fatal("expected IsBifrostError to be true") diff --git a/core/providers/utils/models.go b/core/providers/utils/models.go new file mode 100644 index 0000000000..f1b3d0351b --- /dev/null +++ b/core/providers/utils/models.go @@ -0,0 +1,356 @@ +// Package utils β€” list_models.go +// Centralised pipeline for filtering and backfilling models in ListModels responses. +// +// Every provider's ToBifrostListModelsResponse follows the same logical steps: +// 1. Resolve each API model's name (alias lookup β†’ alias key; else raw model ID) +// 2. Filter (allowlist + blacklist check on the resolved name) +// 3. Backfill entries that were not returned by the API but should appear in output +// +// Providers plug in custom MatchFns to extend the default matching behaviour. +// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns. +package utils + +import ( + "sort" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// ToDisplayName converts a raw model ID or alias key into a human-readable display name. +// Splits on "-" or "_", title-cases each word, and joins with spaces. +// +// "gemini-pro" β†’ "Gemini Pro" +// "claude_3_opus" β†’ "Claude 3 Opus" +// "gpt-4-turbo" β†’ "Gpt 4 Turbo" +func ToDisplayName(id string) string { + caser := cases.Title(language.English) + parts := strings.FieldsFunc(id, func(r rune) bool { + return r == '-' || r == '_' + }) + if len(parts) == 0 { + return "" + } + for i, part := range parts { + if part != "" { + parts[i] = caser.String(strings.ToLower(part)) + } + } + return strings.Join(parts, " ") +} + +// MatchFn reports whether two model ID strings should be treated as equivalent. +// Functions are applied in order during every comparison β€” the first one that +// returns true short-circuits the rest. +// +// Example built-in fns (see DefaultMatchFns): +// +// exactMatch("gpt-4", "gpt-4") β†’ true +// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5") β†’ true +type MatchFn func(a, b string) bool + +// DefaultMatchFns returns the standard matching functions used by most providers. +// Currently only performs case-insensitive exact matching. +// +// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" β‰ˆ "claude-3-5-sonnet") +// is intentionally excluded β€” users should use aliases for explicit version-to-base-name mapping. +// It can be appended here if fuzzy base-model matching is ever needed globally. +func DefaultMatchFns() []MatchFn { + return []MatchFn{ + func(a, b string) bool { return strings.EqualFold(a, b) }, + } +} + +// matches reports whether a and b are considered equal by any of the provided fns. +// Returns true on the first fn that returns true. +func matches(a, b string, fns []MatchFn) bool { + for _, fn := range fns { + if fn(a, b) { + return true + } + } + return false +} + +// FilterResult is the outcome of running Pipeline.FilterModel for a single model +// from the provider's API response. Each returned result represents one alias +// entry (or the raw model ID when no alias matched) that passed all filters. +type FilterResult struct { + // ResolvedID is the user-facing model name to use as the ID suffix. + // If the model matched an alias VALUE, this is the alias KEY. + // Otherwise this is the original model ID from the API response. + // + // Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"} + // β†’ ResolvedID = "my-gpt4" + // Example: API returns "gpt-3.5-turbo", no alias match + // β†’ ResolvedID = "gpt-3.5-turbo" + ResolvedID string + + // AliasValue is the provider-specific model ID when the model was matched + // via an alias. Set as the model.Alias field so callers know the underlying ID. + // Empty when the model was matched directly (no alias involved). + // + // Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched + // β†’ AliasValue = "gpt-4-turbo" + AliasValue string +} + +// Pipeline holds all the context needed to filter and backfill models in a +// single ListModels response. Construct one per ToBifrostListModelsResponse call +// and use its methods instead of passing params + matchFns to every function. +// +// pipeline := &providerUtils.ListModelsPipeline{ +// AllowedModels: key.Models, +// BlacklistedModels: key.BlacklistedModels, +// Aliases: key.Aliases, +// Unfiltered: request.Unfiltered, +// ProviderKey: schemas.OpenAI, +// MatchFns: providerUtils.DefaultMatchFns(), +// } +// if pipeline.ShouldEarlyExit() { return empty } +// result := pipeline.FilterModel(model.ID) +// pipeline.BackfillModels(included) +type ListModelsPipeline struct { + AllowedModels schemas.WhiteList + BlacklistedModels schemas.BlackList + // Aliases maps user-facing alias keys to provider-specific model IDs. + // e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"} + Aliases map[string]string + Unfiltered bool + ProviderKey schemas.ModelProvider + // MatchFns is the ordered list of equivalence functions used for every + // model ID comparison. Use DefaultMatchFns() for standard behaviour; + // providers may append additional fns (e.g. Bedrock's region-prefix remover). + MatchFns []MatchFn +} + +// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately +// return an empty response without processing any models. +// +// Returns true when: +// - not unfiltered AND allowlist is empty AND no aliases configured +// (there is nothing to match against β€” all models would be filtered out anyway) +// - not unfiltered AND blacklist blocks everything +// +// Note: allowlist empty + aliases present β†’ do NOT early exit. +// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels). +func (p *ListModelsPipeline) ShouldEarlyExit() bool { + if p.Unfiltered { + return false + } + if p.BlacklistedModels.IsBlockAll() { + return true + } + if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 { + return true + } + return false +} + +// aliasMatch holds a single alias key/value pair returned by resolveModelID. +type aliasMatch struct { + key string + value string +} + +// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns, +// plus the raw model ID itself as an additional entry so both the alias key and the original model +// name appear in the list-models output. +// Results are sorted by alias key (case-insensitive) for deterministic ordering. +// +// If one or more aliases match β†’ returns one aliasMatch per matching alias key, plus the raw ID. +// +// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"} +// β†’ [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}] +// +// If no alias matches β†’ returns a single entry with the original model ID and no alias value. +// +// Example: modelID="gpt-3.5-turbo", no alias match +// β†’ [{key:"gpt-3.5-turbo", value:""}] +func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch { + var candidates []aliasMatch + for aliasKey, providerID := range p.Aliases { + if matches(modelID, providerID, p.MatchFns) { + candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID}) + } + } + if len(candidates) == 0 { + return []aliasMatch{{key: modelID, value: ""}} + } + // Also include the raw model ID so both the alias key and the original name appear in output. + candidates = append(candidates, aliasMatch{key: modelID, value: ""}) + sort.Slice(candidates, func(i, j int) bool { + return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key) + }) + return candidates +} + +// FilterModel applies the full filter pipeline for a single model from the API response. +// +// Steps: +// 1. Resolve name β€” check alias VALUES for a match (uses MatchFns). +// If matched: resolvedName = alias KEY, aliasValue = provider ID. +// If not matched: resolvedName = original modelID, aliasValue = "". +// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard): +// Skip if resolvedName is not in AllowedModels. +// 3. Blacklist check (always): +// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything. +// 4. Return one FilterResult per passing candidate. +// +// An empty slice means the model should be skipped entirely. +// When multiple aliases map to the same provider model ID, each alias that passes +// the filters produces its own FilterResult entry. +// +// Examples: +// +// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") β†’ [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// FilterModel("gpt-3.5") β†’ [] (not in allowlist) +// +// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") β†’ [{ResolvedID:"gpt-4-turbo", AliasValue:""}, +// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"}, +// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// +// allowedModels=["gpt-3.5"], aliases={}, blacklist=[] +// FilterModel("gpt-3.5") β†’ [{ResolvedID:"gpt-3.5", AliasValue:""}] +// FilterModel("gpt-4") β†’ [] +func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult { + // Step 1: resolve name β€” collect all alias matches (or the raw ID if none match). + candidates := p.resolveModelID(modelID) + + var results []FilterResult + for _, candidate := range candidates { + resolvedName := candidate.key + + // Step 2: allowlist check. + // IsRestricted() is true for both an explicit list AND an empty list (deny-all). + // Only a wildcard allowlist marker bypasses this check (pass-through). + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + allowed := false + for _, entry := range p.AllowedModels { + if matches(resolvedName, entry, p.MatchFns) { + allowed = true + break + } + } + if !allowed { + continue + } + } + + // Step 3: blacklist check β€” blacklist always wins regardless of allowlist or aliases. + if !p.Unfiltered { + blacklisted := false + for _, entry := range p.BlacklistedModels { + if matches(resolvedName, entry, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + } + + results = append(results, FilterResult{ + ResolvedID: resolvedName, + AliasValue: candidate.value, + }) + } + return results +} + +// BackfillModels adds model entries that were configured by the caller but not +// returned by the provider's API response (or not matched during filtering). +// +// The `included` map tracks model IDs (lowercased) already added during the +// filter pass, used to avoid duplicates. +// +// Two cases depending on whether the allowlist is restricted: +// +// Case A β€” allowlist restricted (caller specified explicit model names): +// +// Add each allowlist entry that is not yet in `included`, skip if blacklisted. +// If the entry has an alias mapping (aliases[entry] exists), set Alias to the +// provider-specific ID so callers can route to the right model. +// +// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"} +// "my-gpt4" not in included β†’ add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// "gpt-3.5" not in included β†’ add {ID:"openai/gpt-3.5"} +// +// Case B β€” allowlist wildcard (*) only: +// +// We don't know all model names (no explicit list), so we only backfill entries +// that were explicitly configured via aliases and not yet matched from the API. +// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard. +// +// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included +// β†’ add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// +// Blacklist always wins β€” nothing blacklisted is added in either case. +func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model { + var result []schemas.Model + + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + // Case A: backfill explicit allowlist entries not yet matched. + for _, entry := range p.AllowedModels { + if included[strings.ToLower(entry)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(entry, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + m := schemas.Model{ + ID: string(p.ProviderKey) + "/" + entry, + Name: schemas.Ptr(ToDisplayName(entry)), + } + // If this allowlist entry has an alias, surface the provider-specific ID. + for aliasKey, providerID := range p.Aliases { + if matches(entry, aliasKey, p.MatchFns) { + m.Alias = schemas.Ptr(providerID) + break + } + } + result = append(result, m) + } + return result + } + + // Case B: wildcard allowlist β€” backfill only explicitly configured aliases. + if !p.Unfiltered && len(p.Aliases) > 0 { + for aliasKey, providerID := range p.Aliases { + if included[strings.ToLower(aliasKey)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(aliasKey, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + result = append(result, schemas.Model{ + ID: string(p.ProviderKey) + "/" + aliasKey, + Name: schemas.Ptr(ToDisplayName(aliasKey)), + Alias: schemas.Ptr(providerID), + }) + } + } + + return result +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index c1b7259375..6cadc5a62c 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -178,12 +178,12 @@ func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f } // Check for timeout errors first before checking net.OpError to avoid misclassification if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop + return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check if error implements net.Error and has Timeout() == true var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { - return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop + return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError @@ -1043,7 +1043,7 @@ func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{ } // CheckContextAndGetRequestBody checks if the raw request body should be used, and returns it if it exists. -func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter, providerType schemas.ModelProvider) ([]byte, *schemas.BifrostError) { +func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) { if IsLargePayloadPassthroughEnabled(ctx) { return nil, nil } @@ -1052,15 +1052,15 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette if !ok { convertedBody, err := requestConverter() if err != nil { - return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err) } if convertedBody == nil { - return nil, NewBifrostOperationError("request body is not provided", nil, providerType) + return nil, NewBifrostOperationError("request body is not provided", nil) } jsonBody, err := MarshalSortedIndent(convertedBody, "", " ") if err != nil { - return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -1070,7 +1070,7 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette // tool schemas and other order-sensitive JSON structures. jsonBody, err = MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { - return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType) + return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -1367,10 +1367,6 @@ func NewUnsupportedOperationError(requestType schemas.RequestType, providerName Message: fmt.Sprintf("%s is not supported by %s provider", requestType, providerName), Code: schemas.Ptr("unsupported_operation"), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - RequestType: requestType, - }, } } @@ -1593,37 +1589,31 @@ func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult // NewConfigurationError creates a standardized error for configuration errors. // This helper reduces code duplication across providers that have configuration errors. -func NewConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewConfigurationError(message string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewBifrostOperationError creates a standardized error for bifrost operation errors. // This helper reduces code duplication across providers that have bifrost operation errors. -func NewBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewBifrostOperationError(message string, err error) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: message, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewBifrostTimeoutError creates a standardized error for provider request timeout errors. // Sets StatusCode to 504 (Gateway Timeout) and Error.Type to RequestTimedOut, // consistent with HandleStreamTimeout for streaming requests. -func NewBifrostTimeoutError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { +func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError { statusCode := 504 errorType := schemas.RequestTimedOut return &schemas.BifrostError{ @@ -1634,15 +1624,12 @@ func NewBifrostTimeoutError(message string, err error, providerType schemas.Mode Type: &errorType, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } // NewProviderAPIError creates a standardized error for provider API errors. // This helper reduces code duplication across providers that have provider API errors. -func NewProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError { +func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, @@ -1653,20 +1640,9 @@ func NewProviderAPIError(message string, err error, statusCode int, providerType Error: err, Type: errorType, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerType, - }, } } -// RequestMetadata contains metadata about a request for error reporting. -// This struct is used to pass request context to parseError functions. -type RequestMetadata struct { - Provider schemas.ModelProvider - Model string - RequestType schemas.RequestType -} - // ShouldSendBackRawRequest checks if the raw request should be captured. // Context overrides are intentionally restricted to asymmetric behavior: a context value can only // promote falseβ†’true and will not override a true config to false, avoiding accidental suppression. @@ -1694,17 +1670,14 @@ func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse b } // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event. -func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { firstChunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: 0, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider, - ModelRequested: model, - ChunkIndex: 0, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: 0, + Latency: time.Since(startTime).Milliseconds(), }, } //TODO add bifrost response pooling here @@ -1715,17 +1688,14 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner } // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event -func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { +func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) { chunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: 1, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: provider, - ModelRequested: model, - ChunkIndex: 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: 1, + Latency: time.Since(startTime).Milliseconds(), }, } //TODO add bifrost response pooling here @@ -2015,9 +1985,6 @@ func HandleStreamCancellation( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, - provider schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) { // Check if already handled (StreamEndIndicator already set) @@ -2033,11 +2000,6 @@ func HandleStreamCancellation( Message: "Request cancelled: client disconnected", Type: schemas.Ptr(schemas.RequestCancelled), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider, - ModelRequested: model, - RequestType: requestType, - }, } // Send through PostHook chain - this updates the log to "error" status @@ -2056,9 +2018,6 @@ func HandleStreamTimeout( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, - provider schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) { // Check if already handled (StreamEndIndicator already set) @@ -2074,11 +2033,6 @@ func HandleStreamTimeout( Message: "Request timed out: deadline exceeded", Type: schemas.Ptr(schemas.RequestTimedOut), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider, - ModelRequested: model, - RequestType: requestType, - }, } // Send through PostHook chain - this updates the log to "error" status @@ -2094,9 +2048,6 @@ func ProcessAndSendError( postHookRunner schemas.PostHookRunner, err error, responseChan chan *schemas.BifrostStreamChunk, - requestType schemas.RequestType, - providerName schemas.ModelProvider, - model string, logger schemas.Logger, ) { // Send scanner error through channel @@ -2107,11 +2058,6 @@ func ProcessAndSendError( Message: fmt.Sprintf("Error reading stream: %v", err), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - }, } processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) @@ -2144,8 +2090,6 @@ func CreateBifrostTextCompletionChunkResponse( finishReason *string, currentChunkIndex int, requestType schemas.RequestType, - providerName schemas.ModelProvider, - model string, ) *schemas.BifrostTextCompletionResponse { response := &schemas.BifrostTextCompletionResponse{ ID: id, @@ -2158,10 +2102,7 @@ func CreateBifrostTextCompletionChunkResponse( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - ChunkIndex: currentChunkIndex + 1, + ChunkIndex: currentChunkIndex + 1, }, } return response @@ -2173,14 +2114,15 @@ func CreateBifrostChatCompletionChunkResponse( usage *schemas.BifrostLLMUsage, finishReason *string, currentChunkIndex int, - requestType schemas.RequestType, - providerName schemas.ModelProvider, model string, + created int, ) *schemas.BifrostChatResponse { response := &schemas.BifrostChatResponse{ - ID: id, - Object: "chat.completion.chunk", - Usage: usage, + ID: id, + Model: model, + Created: created, + Object: "chat.completion.chunk", + Usage: usage, Choices: []schemas.BifrostResponseChoice{ { FinishReason: finishReason, @@ -2190,10 +2132,7 @@ func CreateBifrostChatCompletionChunkResponse( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - ChunkIndex: currentChunkIndex + 1, + ChunkIndex: currentChunkIndex + 1, }, } return response @@ -2363,10 +2302,7 @@ func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse // extractSuccessfulListModelsResponses extracts successful responses from a results channel // and tracks per-key status information. This utility reduces code duplication across providers // for handling multi-key ListModels requests. -func extractSuccessfulListModelsResponses( - results chan schemas.ListModelsByKeyResult, - providerName schemas.ModelProvider, -) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) { +func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) { var successfulResponses []*schemas.BifrostListModelsResponse var keyStatuses []schemas.KeyStatus var lastError *schemas.BifrostError @@ -2384,7 +2320,7 @@ func extractSuccessfulListModelsResponses( getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, errMsg)) keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, - Provider: providerName, + Provider: provider, Status: schemas.KeyStatusListModelsFailed, Error: result.Err, }) @@ -2394,7 +2330,7 @@ func extractSuccessfulListModelsResponses( keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, - Provider: providerName, + Provider: provider, Status: schemas.KeyStatusSuccess, }) successfulResponses = append(successfulResponses, result.Response) @@ -2409,10 +2345,6 @@ func extractSuccessfulListModelsResponses( Error: &schemas.ErrorField{ Message: "all keys failed to list models", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }, } } @@ -2495,8 +2427,6 @@ func HandleMultipleListModelsRequests( // Set ExtraFields latency := time.Since(startTime) - response.ExtraFields.Provider = request.Provider - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil @@ -2653,10 +2583,10 @@ func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostRe if accumulatedResp != nil { // Use accumulated response for attributes (includes full content, tool calls, etc.) - tracer.PopulateLLMResponseAttributes(handle, accumulatedResp, err) + tracer.PopulateLLMResponseAttributes(ctx, handle, accumulatedResp, err) } else if result != nil { // Fall back to final chunk if no accumulated data (shouldn't happen normally) - tracer.PopulateLLMResponseAttributes(handle, result, err) + tracer.PopulateLLMResponseAttributes(ctx, handle, result, err) } // Finalize aggregated post-hook spans before ending the LLM span diff --git a/core/providers/vertex/embedding.go b/core/providers/vertex/embedding.go index 0fc0ad598f..54662f50fe 100644 --- a/core/providers/vertex/embedding.go +++ b/core/providers/vertex/embedding.go @@ -110,8 +110,6 @@ func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B Data: embeddings, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: schemas.Vertex, }, } } diff --git a/core/providers/vertex/errors.go b/core/providers/vertex/errors.go index 6b255835d4..e0ed7f1d3d 100644 --- a/core/providers/vertex/errors.go +++ b/core/providers/vertex/errors.go @@ -10,25 +10,13 @@ import ( "github.com/valyala/fasthttp" ) -func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { - var providerName schemas.ModelProvider - if meta != nil { - providerName = meta.Provider - } - +func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError { var openAIErr schemas.BifrostError var vertexErr []VertexError decodedBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) return bifrostErr } @@ -42,13 +30,6 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada Message: schemas.ErrProviderResponseEmpty, }, } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } return bifrostErr } @@ -61,26 +42,20 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada Message: schemas.ErrProviderResponseHTML, Error: errors.New(string(decodedBody)), }, - } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } + ExtraFields: schemas.BifrostErrorExtraFields{ + RawResponse: string(decodedBody), + }, } return bifrostErr } createError := func(message string) *schemas.BifrostError { - bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), providerName, nil, nil) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } + bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil) + var rawResponse interface{} + if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil { + rawResponse = string(decodedBody) } + bifrostErr.ExtraFields.RawResponse = rawResponse return bifrostErr } @@ -93,14 +68,7 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada // Try VertexValidationError format (validation errors from Mistral endpoint) var validationErr VertexValidationError if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil { - bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } + bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) return bifrostErr } if len(validationErr.Detail) > 0 { diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go index 28b5598022..48837563eb 100644 --- a/core/providers/vertex/models.go +++ b/core/providers/vertex/models.go @@ -1,12 +1,10 @@ package vertex import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "golang.org/x/text/cases" - "golang.org/x/text/language" ) // VertexRankRequest represents the Discovery Engine rank API request. @@ -56,49 +54,6 @@ type vertexRerankOptions struct { UserLabels map[string]string } -// formatDeploymentName converts a deployment alias into a human-readable name. -// It splits the alias by "-" or "_", capitalizes each word, and joins them with spaces. -// Example: "gemini-pro" β†’ "Gemini Pro", "claude_3_opus" β†’ "Claude 3 Opus" -func formatDeploymentName(alias string) string { - caser := cases.Title(language.English) - - // Try splitting by hyphen first, then underscore - var parts []string - if strings.Contains(alias, "-") { - parts = strings.Split(alias, "-") - } else if strings.Contains(alias, "_") { - parts = strings.Split(alias, "_") - } else { - // No delimiter found, just capitalize the whole string - return caser.String(strings.ToLower(alias)) - } - - // Capitalize each part - for i, part := range parts { - if part != "" { - parts[i] = caser.String(strings.ToLower(part)) - } - } - - return strings.Join(parts, " ") -} - -// findDeploymentMatch finds a matching deployment value in the deployments map. -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, customModelID string) (deploymentValue, alias string) { - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == customModelID { - return depValue, aliasKey - } - } - // Check exact match by alias/key - if deployment, ok := deployments[customModelID]; ok { - return deployment, customModelID - } - return "", "" -} - // ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format. // It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels). // @@ -114,7 +69,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d // - If allowedModels is empty, all models are allowed // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included // - Deployments map is used to match model IDs to aliases and filter accordingly -func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -123,10 +78,22 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod Data: make([]schemas.Model, 0, len(response.Models)), } - // Track which model IDs have been added to avoid duplicates - addedModelIDs := make(map[string]bool) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Vertex, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) - // First pass: Process all models from the Vertex AI API response (custom models) + // Process all models from the Vertex AI API response (custom deployed models). + // The model ID is extracted from the endpoint URL last segment. for _, model := range response.Models { if len(model.DeployedModels) == 0 { continue @@ -142,110 +109,28 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod continue } - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean "allow all" for that dimension - var deploymentValue, deploymentAlias string - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) + for _, result := range pipeline.FilterModel(customModelID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { + continue } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - shouldFilter = !slices.Contains(allowedModels, customModelID) - } else if !unfiltered && len(deployments) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - modelID := customModelID - - if !unfiltered && (slices.Contains(blacklistedModels, customModelID) || slices.Contains(blacklistedModels, deploymentAlias)) { - continue - } - - modelEntry := schemas.Model{ - ID: string(schemas.Vertex) + "/" + modelID, - Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.VersionCreateTime.Unix()), - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(schemas.Vertex) + "/" + deploymentAlias - modelEntry.Deployment = schemas.Ptr(deploymentValue) - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelEntry.ID] = true - } - } - - // Second pass: Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - // Check if model already exists in the list - modelID := string(schemas.Vertex) + "/" + alias - if addedModelIDs[modelID] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if slices.Contains(blacklistedModels, alias) { - continue - } - - modelName := formatDeploymentName(alias) - modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - Deployment: schemas.Ptr(deploymentValue), + modelEntry := schemas.Model{ + ID: string(schemas.Vertex) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.VersionCreateTime.Unix()), + } + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[resolvedKey] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelID] = true } } - // Third pass: Backfill allowed models that were not in the response or deployments - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - // Check if model already exists in the list - modelID := string(schemas.Vertex) + "/" + allowedModel - if addedModelIDs[modelID] { - continue - } - if slices.Contains(blacklistedModels, allowedModel) { - continue - } - - modelName := formatDeploymentName(allowedModel) - modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[modelID] = true - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) bifrostResponse.NextPageToken = response.NextPageToken @@ -254,7 +139,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format. // This is for foundation models from the Model Garden (publishers.models.list endpoint). -func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -263,8 +148,19 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a Data: make([]schemas.Model, 0, len(response.PublisherModels)), } - // Track which model IDs have been added to avoid duplicates - addedModelIDs := make(map[string]bool) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Vertex, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) for _, model := range response.PublisherModels { // Extract model ID from name (format: "publishers/google/models/gemini-1.5-pro") @@ -273,35 +169,27 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a continue } - // Filter based on allowedModels if specified - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, modelID) { - continue - } - - // Skip if already added (shouldn't happen, but safety check) - fullModelID := string(schemas.Vertex) + "/" + modelID - if addedModelIDs[fullModelID] { - continue - } - - // Extract display name from supported actions if available - displayName := modelID - if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" { - displayName = model.SupportedActions.Deploy.ModelDisplayName - } - - modelEntry := schemas.Model{ - ID: fullModelID, - Name: schemas.Ptr(displayName), + for _, result := range pipeline.FilterModel(modelID) { + // Extract display name from supported actions if available + displayName := result.ResolvedID + if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" { + displayName = model.SupportedActions.Deploy.ModelDisplayName + } + modelEntry := schemas.Model{ + ID: string(schemas.Vertex) + "/" + result.ResolvedID, + Name: schemas.Ptr(displayName), + } + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) - addedModelIDs[fullModelID] = true } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + bifrostResponse.NextPageToken = response.NextPageToken return bifrostResponse diff --git a/core/providers/vertex/rerank.go b/core/providers/vertex/rerank.go index 74372658b2..b06430fcac 100644 --- a/core/providers/vertex/rerank.go +++ b/core/providers/vertex/rerank.go @@ -83,7 +83,7 @@ func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) } // ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format. -func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployment string, options *vertexRerankOptions) (*VertexRankRequest, error) { +func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost rerank request is nil") } @@ -132,7 +132,7 @@ func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployme rankRequest.TopN = &topN } - if trimmedModel := strings.TrimSpace(modelDeployment); trimmedModel != "" { + if trimmedModel := strings.TrimSpace(bifrostReq.Model); trimmedModel != "" { rankRequest.Model = &trimmedModel } diff --git a/core/providers/vertex/rerank_test.go b/core/providers/vertex/rerank_test.go index afd8ed225e..3f2efcec52 100644 --- a/core/providers/vertex/rerank_test.go +++ b/core/providers/vertex/rerank_test.go @@ -42,7 +42,6 @@ func TestToVertexRankRequest(t *testing.T) { TopN: schemas.Ptr(10), }, }, - "semantic-ranker-default@latest", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, @@ -77,7 +76,6 @@ func TestToVertexRankRequestTooManyRecords(t *testing.T) { Query: "q", Documents: docs, }, - "", &vertexRerankOptions{ RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config", IgnoreRecordDetailsInResponse: true, diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go index 2e4e225f4d..0dfda6763d 100644 --- a/core/providers/vertex/utils.go +++ b/core/providers/vertex/utils.go @@ -9,7 +9,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader β€” skip all body building // (matches CheckContextAndGetRequestBody guard). if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -26,74 +26,74 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Add max_tokens if not present if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Remap unsupported tool versions for Vertex (e.g., web_search_20260209 β†’ web_search_20250305) jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if err != nil { - return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(err.Error(), nil) } // Add anthropic_version if not present if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Validate tools are supported by Vertex if request.Params != nil && request.Params.Tools != nil { if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil { - return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil) } } // Convert request to Anthropic format reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } reqBody.Model = deployment @@ -109,44 +109,44 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Add anthropic_version if not present (using sjson to preserve order) if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if isCountTokens { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } else { // Remove model field for Vertex API (it's in URL) jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 { jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -178,29 +178,25 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID // buildResponseFromConfig builds a list models response from configured deployments and allowedModels. // This is used when the user has explicitly configured which models they want to use. -func buildResponseFromConfig(deployments map[string]string, allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { +func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse { response := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } + if blacklistedModels.IsBlockAll() { + return response + } + addedModelIDs := make(map[string]bool) - // Build allowlist set for O(1) lookup - allowedSet := make(map[string]bool, len(allowedModels)) - for _, m := range allowedModels { - allowedSet[m] = true - } - blacklistedSet := make(map[string]bool, len(blacklistedModels)) - for _, m := range blacklistedModels { - blacklistedSet[m] = true - } + restrictAllowed := allowedModels.IsRestricted() // First add models from deployments (filtered by allowedModels when set) for alias, deploymentValue := range deployments { - if len(allowedSet) > 0 && !allowedSet[alias] { + if restrictAllowed && !allowedModels.Contains(alias) { continue } - if len(blacklistedSet) > 0 && blacklistedSet[alias] { + if blacklistedModels.IsBlocked(alias) { continue } modelID := string(schemas.Vertex) + "/" + alias @@ -208,28 +204,31 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri continue } - modelName := formatDeploymentName(alias) + modelName := providerUtils.ToDisplayName(alias) modelEntry := schemas.Model{ - ID: modelID, - Name: schemas.Ptr(modelName), - Deployment: schemas.Ptr(deploymentValue), + ID: modelID, + Name: schemas.Ptr(modelName), + Alias: schemas.Ptr(deploymentValue), } response.Data = append(response.Data, modelEntry) addedModelIDs[modelID] = true } - // Then add models from allowedModels that aren't already in deployments + // Then add models from allowedModels that aren't already in deployments (only when restricted) + if !restrictAllowed { + return response + } for _, allowedModel := range allowedModels { - if len(blacklistedSet) > 0 && blacklistedSet[allowedModel] { - continue - } modelID := string(schemas.Vertex) + "/" + allowedModel if addedModelIDs[modelID] { continue } + if blacklistedModels.IsBlocked(allowedModel) { + continue + } - modelName := formatDeploymentName(allowedModel) + modelName := providerUtils.ToDisplayName(allowedModel) modelEntry := schemas.Model{ ID: modelID, Name: schemas.Ptr(modelName), diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 525e3dada5..626cf00207 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -114,9 +114,6 @@ const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform" // It uses the JWT config if auth credentials are provided. // It returns an error if the token source creation fails. func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) { - if key.VertexKeyConfig == nil { - return nil, fmt.Errorf("vertex key config is not set") - } authCredentials := key.VertexKeyConfig.AuthCredentials var tokenSource oauth2.TokenSource if authCredentials.GetValue() == "" { @@ -176,23 +173,21 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // 1. If deployments or allowedModels are configured, return those (no API call needed) // 2. Otherwise, fetch from the publishers.models.list API endpoint (Model Garden) func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - deployments := key.VertexKeyConfig.Deployments + deployments := key.Aliases allowedModels := key.Models + if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) { + return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil + } + // If deployments or allowedModels are configured, return those directly without API call // Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved - if !request.Unfiltered && (len(deployments) > 0 || len(allowedModels) > 0) { + if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) { return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil } @@ -213,11 +208,11 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err) } // Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models @@ -246,13 +241,14 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) if bifrostErr != nil { wait() + respBody := append([]byte(nil), resp.Body()...) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) // Non-Google publishers may not be available in all regions; skip on error if publisher != "google" { break } - return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -280,9 +276,9 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key var errorResp VertexError if err := sonic.Unmarshal(respBody, &errorResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, schemas.Vertex, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse Vertex's publisher models response @@ -322,7 +318,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key PublisherModels: allPublisherModels, } - response := aggregatedResponse.ToBifrostListModelsResponse(nil, key.BlacklistedModels, request.Unfiltered) + response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequests @@ -368,18 +364,6 @@ func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -389,7 +373,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use centralized Anthropic converter reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request) if convErr != nil { @@ -399,7 +383,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) // Marshal to JSON bytes, preserving struct field order @@ -426,7 +409,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to delete model field: %w", err) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody, err := gemini.ToGeminiChatCompletionRequest(request) if err != nil { return nil, err @@ -435,7 +418,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes @@ -450,7 +432,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Marshal to JSON bytes rawBody, err = providerUtils.MarshalSorted(reqBody) if err != nil { @@ -465,26 +446,26 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Remap unsupported tool versions for Vertex (handles raw passthrough bodies) - if schemas.IsAnthropicModel(deployment) && jsonBody != nil { + if schemas.IsAnthropicModel(request.Model) && jsonBody != nil { remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } jsonBody = remappedBody } @@ -493,43 +474,43 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsAnthropicModel(deployment) { + } else if schemas.IsAnthropicModel(request.Model) { // Claude models use Anthropic publisher if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsMistralModel(deployment) { + } else if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with rawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { // Gemini models support api key if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } else { if region == "global" { @@ -564,11 +545,11 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -597,14 +578,10 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -613,16 +590,13 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Create response object from pool anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) @@ -636,17 +610,9 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostChatResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), - } - - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment + Latency: latency.Milliseconds(), + ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } - response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -659,7 +625,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -668,12 +634,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } response := geminiResponse.ToBifrostChatResponse() - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -695,12 +655,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -723,35 +677,17 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails. func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use Anthropic-style streaming for Claude models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -766,8 +702,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment - reqBody.Stream = schemas.Ptr(true) + reqBody.Stream = new(true) // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) @@ -803,7 +738,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -813,15 +748,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext var remapErr error jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } } var completeURL string if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -834,11 +769,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken @@ -855,15 +790,10 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), providerName, postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { // Use Gemini-style streaming for Gemini models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -876,12 +806,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if reqBody == nil { return nil, fmt.Errorf("chat completion input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -894,12 +823,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { @@ -918,11 +847,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } @@ -940,7 +869,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -949,12 +878,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsMistralModel(deployment) { + if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with streamRawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model) } } else { // Other models use OpenAPI endpoint for gemini models @@ -974,22 +903,17 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } authHeader = map[string]string{ "Authorization": "Bearer " + token.AccessToken, } } - postRequestConverter := func(reqBody *openai.OpenAIChatRequest) *openai.OpenAIChatRequest { - reqBody.Model = deployment - return reqBody - } - // Use shared OpenAI streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -1005,8 +929,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -1014,40 +938,28 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Responses performs a responses request to the Vertex API. func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Claude models use Anthropic publisher var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } // Create HTTP request for streaming @@ -1068,11 +980,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1100,14 +1012,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1115,9 +1023,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1137,13 +1042,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem response := anthropicResponse.ToBifrostResponsesResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1154,12 +1055,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse } - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -1171,24 +1069,23 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if reqBody == nil { return nil, fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -1198,11 +1095,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":generateContent") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":generateContent") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1227,11 +1124,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -1260,14 +1157,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1275,9 +1168,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1292,16 +1182,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem } response := geminiResponse.ToResponsesBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse @@ -1319,52 +1202,33 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response, nil } } // ResponsesStream performs a streaming responses request to the Vertex API. func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -1377,22 +1241,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Anthropic return anthropic.HandleAnthropicResponsesStream( ctx, @@ -1406,23 +1262,18 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } // Use Gemini-style streaming for Gemini models @@ -1437,12 +1288,11 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if reqBody == nil { return nil, fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1455,12 +1305,12 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery) @@ -1478,23 +1328,15 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - // Use shared streaming logic from Gemini return gemini.HandleGeminiResponsesStream( ctx, @@ -1508,7 +1350,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -1526,18 +1368,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // All Vertex AI embedding models use the same response format regardless of the model type. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1546,24 +1384,19 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVertexEmbeddingRequest(request), nil }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } - deployment := provider.getModelDeployment(key, request.Model) - - // Remove google/ prefix from deployment - deployment = strings.TrimPrefix(deployment, "google/") - // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Build the native Vertex embedding API endpoint - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predict") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1586,11 +1419,11 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1626,7 +1459,7 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Try to parse Vertex's error format var vertexError map[string]interface{} if err := sonic.Unmarshal(errBody, &vertexError); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errorObj, exists := vertexError["error"]; exists { @@ -1640,10 +1473,10 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem } } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), schemas.Vertex, nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1651,9 +1484,6 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostEmbeddingResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1663,28 +1493,21 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem // Parse Vertex's native embedding response using typed response var vertexResponse VertexEmbeddingResponse if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Use centralized Vertex converter bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) - if bifrostResponse.ExtraFields.ModelRequested != deployment { - bifrostResponse.ExtraFields.ModelDeployment = deployment - } - // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { var rawResponseMap map[string]interface{} if err := sonic.Unmarshal(resp.Body(), &rawResponseMap); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.ExtraFields.RawResponse = rawResponseMap } @@ -1699,30 +1522,23 @@ func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas. // Rerank performs a rerank request using Vertex Discovery Engine ranking API. func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } options, err := getVertexRerankOptions(projectID, request.Params) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } - modelDeployment := provider.getModelDeployment(key, request.Model) jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToVertexRankRequest(request, modelDeployment, options) + return ToVertexRankRequest(request, options) }, - providerName) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -1748,11 +1564,11 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1780,11 +1596,7 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. } errorMessage := parseDiscoveryEngineErrorMessage(resp.Body()) - parsedError := parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + parsedError := parseVertexError(resp) if strings.TrimSpace(errorMessage) != "" { shouldOverride := parsedError == nil || @@ -1794,19 +1606,14 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. parsedError.Error.Message == schemas.ErrProviderResponseUnmarshal if shouldOverride { - parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), providerName, nil, nil) - parsedError.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, - } + parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil) } } return nil, providerUtils.EnrichError(ctx, parsedError, jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1815,9 +1622,6 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1833,16 +1637,9 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas. returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments bifrostResponse, err := vertexResponse.ToBifrostRerankResponse(request.Documents, returnDocuments) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - if request.Model != modelDeployment { - bifrostResponse.ExtraFields.ModelDeployment = modelDeployment - } - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -1873,21 +1670,9 @@ func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, } func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -1898,13 +1683,12 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -1912,7 +1696,7 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageGenerationRequest(request) if reqBody == nil { return nil, fmt.Errorf("image generation input is not provided") @@ -1932,58 +1716,58 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used for fine-tuned models to pass the API key in the query string authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { // Imagen models are published models, use publishers/google/models path if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if value := key.Value.GetValue(); value != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value)) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } @@ -2010,11 +1794,11 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2043,14 +1827,10 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2058,16 +1838,13 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2080,12 +1857,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2108,12 +1879,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageGenerationRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2137,20 +1902,9 @@ func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContex // ImageEdit edits images for the given input text(s) using Vertex AI. // Returns a BifrostResponse containing the images and any error that occurred. func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Validate model type before processing - if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", deployment), providerName) + if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model)) } jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -2161,13 +1915,12 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem var extraParams map[string]interface{} var err error - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody := gemini.ToGeminiImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) // Marshal to JSON bytes, preserving key order @@ -2175,7 +1928,7 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if err != nil { return nil, fmt.Errorf("failed to marshal request body: %w", err) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { reqBody := gemini.ToImagenImageEditRequest(request) if reqBody == nil { return nil, fmt.Errorf("image edit input is not provided") @@ -2195,19 +1948,19 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -2216,27 +1969,27 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsImagenModel(deployment) { + } else if schemas.IsImagenModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model) } } @@ -2262,11 +2015,11 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2294,14 +2047,10 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2309,16 +2058,13 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2331,12 +2077,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2359,12 +2099,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2395,18 +2129,9 @@ func (provider *VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, bifrostReq.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Only Gemini models support video generation in Vertex - if !schemas.IsVeoModel(deployment) && !schemas.IsAllDigitsASCII(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", deployment), providerName) + if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model)) } // Convert Bifrost request to Gemini format (reusing Gemini converters) @@ -2416,7 +2141,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2424,12 +2148,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used to pass the API key in the query string @@ -2440,12 +2164,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini video generation using predictLongRunning - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predictLongRunning") + completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning") // Create HTTP request req := fasthttp.AcquireRequest() @@ -2464,11 +2188,11 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2488,17 +2212,13 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: bifrostReq.Model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var operation gemini.GenerateVideosOperation @@ -2514,12 +2234,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = bifrostReq.Model - if bifrostReq.Model != deployment { - bifrostResp.ExtraFields.ModelDeployment = deployment - } - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2535,18 +2249,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // VideoRetrieve retrieves the status of a video generation operation. // Uses the fetchPredictOperation endpoint for Vertex AI. func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Construct base URL based on region @@ -2562,12 +2270,12 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID // We need to extract the model path from it to construct the fetchPredictOperation endpoint // Extract: projects/.../models/MODEL_ID from the operation name - taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, providerName) + taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey()) var modelPath string if idx := strings.Index(taskID, "/operations/"); idx != -1 { modelPath = taskID[:idx] } else { - return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil) } // Construct the URL: https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation @@ -2582,7 +2290,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Create request body with operation name (using sjson to avoid map marshaling) jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err) } // Create HTTP request @@ -2602,11 +2310,11 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2626,10 +2334,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -2643,10 +2348,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if sendBackRawResponse { @@ -2660,9 +2363,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // First retrieves the video status to get the URL, then downloads the content. // Handles both regular URLs and data URLs (base64-encoded videos). func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2676,12 +2378,10 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte var latency time.Duration @@ -2693,7 +2393,7 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded contentType = videoResp.Videos[0].ContentType @@ -2723,11 +2423,11 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2742,19 +2442,17 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) } else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ @@ -2764,8 +2462,6 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2803,19 +2499,6 @@ func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequ } } -func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string) string { - if key.VertexKeyConfig == nil { - return model - } - - if key.VertexKeyConfig.Deployments != nil { - if deployment, ok := key.VertexKeyConfig.Deployments[model]; ok { - return deployment - } - } - return model -} - // BatchCreate is not supported by Vertex AI provider. func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) @@ -2875,25 +2558,13 @@ func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schem // CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint. // Supports Gemini models with both text and image content. func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - var ( jsonBody []byte bifrostErr *schemas.BifrostError ) - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } @@ -2904,7 +2575,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiResponsesRequest(request) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2922,38 +2592,38 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" var completeURL string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { if region == "global" { completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID) } else { completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - completeURL = getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":countTokens") + completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens") } if completeURL == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model/deployment: %s", deployment), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model)) } req := fasthttp.AcquireRequest() @@ -2975,11 +2645,11 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3007,14 +2677,10 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -3022,16 +2688,13 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch respOwned = false return &schemas.BifrostCountTokensResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := &anthropic.AnthropicCountTokensResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -3040,12 +2703,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3068,12 +2725,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := vertexResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3138,14 +2789,9 @@ func (provider *VertexProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3211,12 +2857,12 @@ func (provider *VertexProvider) Passthrough( tokenSource, err := getAuthTokenSource(key) if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3266,7 +2912,7 @@ func (provider *VertexProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -3280,9 +2926,6 @@ func (provider *VertexProvider) Passthrough( } bifrostResponse.ExtraFields.ProviderResponseHeaders = headers - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest - bifrostResponse.ExtraFields.ModelRequested = req.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3298,13 +2941,9 @@ func (provider *VertexProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3370,13 +3009,13 @@ func (provider *VertexProvider) PassthroughStream( if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3416,9 +3055,9 @@ func (provider *VertexProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { @@ -3433,9 +3072,7 @@ func (provider *VertexProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Set stream idle timeout from provider config. @@ -3448,11 +3085,7 @@ func (provider *VertexProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3463,9 +3096,9 @@ func (provider *VertexProvider) PassthroughStream( go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -3514,7 +3147,7 @@ func (provider *VertexProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index 7203bf3080..420b5e58cd 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -38,38 +38,38 @@ func TestVertex(t *testing.T) { ImageEditModel: "imagen-3.0-capability-001", VideoGenerationModel: "veo-3.1-generate-preview", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: true, - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: true, - VideoGeneration: false, // disabled for now because of long running operations - VideoRetrieve: false, - VideoRemix: false, - VideoDownload: false, - VideoList: false, - VideoDelete: false, - MultipleImages: true, - CompleteEnd2End: true, - FileBase64: true, - Embedding: true, - Rerank: rerankModel != "", - Reasoning: true, - PromptCaching: true, - ListModels: false, - CountTokens: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - InterleavedThinking: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: true, + ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: true, + VideoGeneration: false, // disabled for now because of long running operations + VideoRetrieve: false, + VideoRemix: false, + VideoDownload: false, + VideoList: false, + VideoDelete: false, + MultipleImages: true, + CompleteEnd2End: true, + FileBase64: true, + Embedding: true, + Rerank: rerankModel != "", + Reasoning: true, + PromptCaching: true, + ListModels: false, + CountTokens: true, + StructuredOutputs: true, // Structured outputs with nullable enum support + InterleavedThinking: true, }, } diff --git a/core/providers/vllm/utils.go b/core/providers/vllm/utils.go index d2cefce786..ab6d694938 100644 --- a/core/providers/vllm/utils.go +++ b/core/providers/vllm/utils.go @@ -13,9 +13,6 @@ func HandleVLLMResponse[T any](responseBody []byte, response *T, requestBody []b return rawRequest, rawResponse, bifrostErr } if err := sonic.Unmarshal(responseBody, &errorResp); err == nil && errorResp.Error != nil && errorResp.Error.Message != "" { - errorResp.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: schemas.VLLM, - } return rawRequest, rawResponse, &errorResp } return rawRequest, rawResponse, nil diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index eab5828af9..d36aa7354b 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -76,9 +76,7 @@ func (provider *VLLMProvider) baseURLOrError(key schemas.Key) (string, *schemas. if u == "" { return "", providerUtils.NewBifrostOperationError( "no base URL configured: set vllm_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) + nil) } return u, nil } @@ -246,9 +244,6 @@ func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas return nil, err } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -314,12 +309,14 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( statusCode := resp.StatusCode() if statusCode != fasthttp.StatusOK { - return nil, nil, nil, nil, statusCode, latency, openai.ParseOpenAIError(resp, schemas.RerankRequest, provider.GetProviderKey(), request.Model) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, nil, nil, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -336,16 +333,12 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( // Rerank performs a rerank request to vLLM's API. func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVLLMRerankRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -358,6 +351,9 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke resolvedPath = "/" + resolvedPath } + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + responsePayload, rawRequest, rawResponse, responseBody, statusCode, latency, bifrostErr := provider.callVLLMRerankEndpoint(ctx, key, request, resolvedPath, jsonData) if bifrostErr != nil && !hasPathOverride && isRerankFallbackStatus(statusCode) { var fallbackLatency time.Duration @@ -365,7 +361,7 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke latency += fallbackLatency } if bifrostErr != nil { - return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, sendBackRawRequest, sendBackRawResponse) } returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments @@ -373,19 +369,16 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke if err != nil { return nil, providerUtils.EnrichError( ctx, - providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), + providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonData, responseBody, - provider.sendBackRawRequest, - provider.sendBackRawResponse, + sendBackRawRequest, + sendBackRawResponse, ) } // Keep requested model as the canonical model in Bifrost response. bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -435,7 +428,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Use centralized converter reqBody := openai.ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -491,9 +484,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -502,7 +495,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough β€” pipe raw upstream SSE to client @@ -521,9 +514,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -563,7 +556,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -580,11 +573,6 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p _, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger) return @@ -603,11 +591,8 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() diff --git a/core/providers/vllm/vllm_test.go b/core/providers/vllm/vllm_test.go index a9d7a1c17d..2f1d5b22c6 100644 --- a/core/providers/vllm/vllm_test.go +++ b/core/providers/vllm/vllm_test.go @@ -37,35 +37,35 @@ func TestVLLM(t *testing.T) { EmbeddingModel: embeddingModel, RerankModel: rerankModel, Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - Embedding: true, - Rerank: rerankModel != "", - ListModels: true, - Reasoning: true, - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: true, - TranscriptionStream: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, - ImageEditStream: false, - ImageVariation: false, - ImageVariationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + Embedding: true, + Rerank: rerankModel != "", + ListModels: true, + Reasoning: true, + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: true, + TranscriptionStream: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, + ImageEditStream: false, + ImageVariation: false, + ImageVariationStream: false, }, } diff --git a/core/providers/xai/errors.go b/core/providers/xai/errors.go index 78b22463e0..38a46888a8 100644 --- a/core/providers/xai/errors.go +++ b/core/providers/xai/errors.go @@ -15,7 +15,7 @@ type XAIErrorResponse struct { // ParseXAIError parses xAI-specific error responses. // xAI returns errors in format: {"code": "...", "error": "..."} // Unlike OpenAI which uses: {"error": {"message": "...", "type": "...", "code": "..."}} -func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseXAIError(resp *fasthttp.Response) *schemas.BifrostError { // Try to parse xAI error format var xaiErr XAIErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &xaiErr) @@ -35,10 +35,5 @@ func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, pro } } - // Set ExtraFields individually to preserve RawResponse from HandleProviderAPIError - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index ecf285c379..6ec5ebda6b 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -65,7 +65,7 @@ func (provider *XAIProvider) GetProviderKey() schemas.ModelProvider { // ListModels performs a list models request to xAI's API. func (provider *XAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } return openai.HandleOpenAIListModelsRequest( ctx, diff --git a/core/providers/xai/xai_test.go b/core/providers/xai/xai_test.go index 81c479bb7f..cda0616f83 100644 --- a/core/providers/xai/xai_test.go +++ b/core/providers/xai/xai_test.go @@ -32,27 +32,27 @@ func TestXAI(t *testing.T) { EmbeddingModel: "", // XAI doesn't support embedding ImageGenerationModel: "grok-2-image", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - ImageGeneration: true, - ImageGenerationStream: false, - FileBase64: false, - FileURL: false, - MultipleImages: true, - CompleteEnd2End: true, - Reasoning: true, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + ImageGeneration: true, + ImageGenerationStream: false, + FileBase64: false, + FileURL: false, + MultipleImages: true, + CompleteEnd2End: true, + Reasoning: true, + Embedding: false, + ListModels: true, }, } diff --git a/core/schemas/account.go b/core/schemas/account.go index ceaeb2de8a..3dfb16a0b4 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,7 +1,12 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "context" +import ( + "context" + "fmt" + "slices" + "strings" +) type KeyStatusType string @@ -10,26 +15,172 @@ const ( KeyStatusListModelsFailed KeyStatusType = "list_models_failed" ) +// WhiteList is a list of values that are allowed to be used. +// Semantics: +// - "*" (alone) means all values are allowed. +// - Empty list means nothing is allowed. +// - Non-empty list (without "*") means only the listed values are allowed. +// +// This type is used generically for any field that needs whitelist behavior +// (e.g., allowed models, allowed tools). +type WhiteList []string + +// Contains reports whether value is in the whitelist. +// Returns true if value is in the list. +func (wl WhiteList) Contains(value string) bool { + return slices.ContainsFunc(wl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsAllowed reports whether value is in the whitelist. +// Returns true if value is in the list. +func (wl WhiteList) IsAllowed(value string) bool { + return wl.IsUnrestricted() || wl.Contains(value) +} + +// IsEmpty reports whether the whitelist has no entries. +func (wl WhiteList) IsEmpty() bool { + return len(wl) == 0 +} + +// IsUnrestricted reports whether the whitelist contains only "*", +// meaning all values are allowed. +func (wl WhiteList) IsUnrestricted() bool { + return len(wl) == 1 && wl[0] == "*" +} + +// IsRestricted reports whether the whitelist contains entries other than "*", +// meaning only the listed values are allowed. +func (wl WhiteList) IsRestricted() bool { + return !wl.IsUnrestricted() +} + +// Validate checks that the whitelist is well-formed. +// Returns an error if "*" is present alongside other values, or if there are duplicate entries. +func (wl WhiteList) Validate() error { + if wl.Contains("*") && len(wl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist") + } + seen := make(map[string]struct{}, len(wl)) + for _, v := range wl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in whitelist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + +// BlackList is a list of values that are denied. +// Semantics: +// - "*" (alone) means all values are blocked. +// - Empty list means nothing is blocked. +// - Non-empty list (without "*") means only the listed values are blocked. +type BlackList []string + +func (bl BlackList) Contains(value string) bool { + return slices.ContainsFunc(bl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsBlocked reports whether value is blocked. +func (bl BlackList) IsBlocked(value string) bool { + return bl.IsBlockAll() || bl.Contains(value) +} + +// IsEmpty reports whether the blacklist has no entries (nothing is blocked). +func (bl BlackList) IsEmpty() bool { + return len(bl) == 0 +} + +// IsBlockAll reports whether the blacklist contains "*", meaning all values are blocked. +func (bl BlackList) IsBlockAll() bool { + return len(bl) == 1 && bl[0] == "*" +} + +// Validate checks that the blacklist is well-formed. +func (bl BlackList) Validate() error { + if bl.Contains("*") && len(bl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist") + } + seen := make(map[string]struct{}, len(bl)) + for _, v := range bl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in blacklist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. type Key struct { - ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) - Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) - Value EnvVar `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access - BlacklistedModels []string `json:"blacklisted_models"` // List of models this key cannot access - Weight float64 `json:"weight"` // Weight for load balancing between multiple keys - AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration - VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration - BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration - HuggingFaceKeyConfig *HuggingFaceKeyConfig `json:"huggingface_key_config,omitempty"` // Hugging Face-specific key configuration - ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration - VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration - Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) - UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) - ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection - Status KeyStatusType `json:"status,omitempty"` // Status of key - Description string `json:"description,omitempty"` // Description of key + ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) + Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) + Value EnvVar `json:"value"` // The actual API key value + Models WhiteList `json:"models"` // List of models this key can access + BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access + Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to inference profiles + AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration + VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration + BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration + VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration + ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration + OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration + SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration + Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) + UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) + ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection + Status KeyStatusType `json:"status,omitempty"` // Status of key + Description string `json:"description,omitempty"` // Description of key +} + +type KeyAliases map[string]string + +func (ka KeyAliases) Validate() error { + seen := make(map[string]struct{}, len(ka)) + for from, to := range ka { + if strings.TrimSpace(from) == "" { + return fmt.Errorf("alias source cannot be empty") + } + if strings.TrimSpace(to) == "" { + return fmt.Errorf("alias target for %q cannot be empty", from) + } + if strings.TrimSpace(from) != from { + return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from) + } + if strings.TrimSpace(to) != to { + return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from) + } + normalized := strings.ToLower(from) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate alias source %q (case-insensitive)", from) + } + seen[normalized] = struct{}{} + } + return nil +} + +func (ka KeyAliases) Resolve(model string) string { + if ka == nil { + return model + } + if alias, ok := ka[model]; ok { + return alias + } + // Fall back to case-insensitive lookup for consistency with WhiteList.Contains + for k, v := range ka { + if strings.EqualFold(k, model) { + return v + } + } + return model } type AzureAuthType string @@ -42,9 +193,8 @@ const ( // AzureKeyConfig represents the Azure-specific configuration. // It contains Azure-specific settings required for service access and deployment management. type AzureKeyConfig struct { - Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" + Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL + APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication @@ -55,11 +205,10 @@ type AzureKeyConfig struct { // VertexKeyConfig represents the Vertex-specific configuration. // It contains Vertex-specific settings required for authentication and service access. type VertexKeyConfig struct { - ProjectID EnvVar `json:"project_id"` - ProjectNumber EnvVar `json:"project_number"` - Region EnvVar `json:"region"` - AuthCredentials EnvVar `json:"auth_credentials"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles + ProjectID EnvVar `json:"project_id"` + ProjectNumber EnvVar `json:"project_number"` + Region EnvVar `json:"region"` + AuthCredentials EnvVar `json:"auth_credentials"` } // NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string. @@ -90,21 +239,12 @@ type BedrockKeyConfig struct { ExternalID *EnvVar `json:"external_id,omitempty"` RoleSessionName *EnvVar `json:"session_name,omitempty"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles - BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations + BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations } // NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings. // To use Bedrock API Key authentication, set Value in Key struct instead. -type HuggingFaceKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - -type ReplicateKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - // VLLMKeyConfig represents the vLLM-specific key configuration. // It allows each key to target a different vLLM server URL and model name, // enabling per-key routing and round-robin load balancing across multiple vLLM instances. @@ -113,6 +253,26 @@ type VLLMKeyConfig struct { ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection) } +// ReplicateKeyConfig represents the Replicate-specific key configuration. +// It contains Replicate-specific settings required for authentication and service access. +type ReplicateKeyConfig struct { + UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint +} + +// OllamaKeyConfig represents the Ollama-specific key configuration. +// It allows each key to target a different Ollama server URL, +// enabling per-key routing and round-robin load balancing across multiple Ollama instances. +type OllamaKeyConfig struct { + URL EnvVar `json:"url"` // Ollama server base URL (required, supports env. prefix) +} + +// SGLKeyConfig represents the SGLang-specific key configuration. +// It allows each key to target a different SGLang server URL, +// enabling per-key routing and round-robin load balancing across multiple SGLang instances. +type SGLKeyConfig struct { + URL EnvVar `json:"url"` // SGLang server base URL (required, supports env. prefix) +} + // Account defines the interface for managing provider accounts and their configurations. // It provides methods to access provider-specific settings, API keys, and configurations. type Account interface { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index caa14512bc..8f68d2fbad 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -160,32 +160,41 @@ type BifrostContextKey string // BifrostContextKeyRequestType is a context key for the request type. const ( - BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) - BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string - BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) - BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) - BifrostContextKeyRequestID BifrostContextKey = "request-id" // string - BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string - BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) - BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) - BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string - BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string + BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) + BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string + BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) + BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) + BifrostContextKeyRequestID BifrostContextKey = "request-id" // string + BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. + MCPContextKeyIncludeClients BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) + + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitID BifrostContextKey = "bifrost-governance-business-unit-id" // string (to store the business unit ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) + BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) + BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string + BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool @@ -204,7 +213,11 @@ const ( BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) + BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) + BifrostContextKeyMCPUserID BifrostContextKey = "bifrost-mcp-user-id" // string (per-user OAuth user identifier from X-Bf-User-Id header) + BifrostContextKeyOAuthRedirectURI BifrostContextKey = "bifrost-oauth-redirect-uri" // string (OAuth callback URL, e.g. https://host/api/oauth/callback - set by HTTP middleware) + BifrostContextKeyIsMCPGateway BifrostContextKey = "bifrost-is-mcp-gateway" // bool (true when request is being handled via the MCP gateway path) + BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) @@ -215,8 +228,16 @@ const ( BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) + BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream) BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries + BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) + BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request + BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) + BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string + BifrostContextKeyRealtimeProviderSessionID BifrostContextKey = "bifrost-realtime-provider-session-id" // string + BifrostContextKeyRealtimeSource BifrostContextKey = "bifrost-realtime-source" // string ("ei" or "lm") + BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) @@ -227,6 +248,7 @@ const ( BifrostContextKeyVideoOutputRequested BifrostContextKey = "bifrost-video-output-requested" BifrostContextKeyValidateKeys BifrostContextKey = "bifrost-validate-keys" // bool (triggers additional key validation during provider add/update) BifrostContextKeyProviderResponseHeaders BifrostContextKey = "bifrost-provider-response-headers" // map[string]string (set by provider handlers for response header forwarding) + BifrostContextKeyMCPAddedTools BifrostContextKey = "bifrost-mcp-added-tools" // []string (set by bifrost - DO NOT SET THIS MANUALLY)) - list of tools added to the request by MCP, all the tool are in the format "clientName-toolName" BifrostContextKeyLargePayloadMode BifrostContextKey = "bifrost-large-payload-mode" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) indicates large payload streaming mode is active BifrostContextKeyLargePayloadReader BifrostContextKey = "bifrost-large-payload-reader" // io.Reader (set by bifrost - DO NOT SET THIS MANUALLY)) upstream reader for large payloads BifrostContextKeyLargePayloadContentLength BifrostContextKey = "bifrost-large-payload-content-length" // int (set by bifrost - DO NOT SET THIS MANUALLY)) content length for large payloads @@ -247,6 +269,7 @@ const ( BifrostContextKeySSEReaderFactory BifrostContextKey = "bifrost-sse-reader-factory" // *providerUtils.SSEReaderFactory (set by enterprise β€” replaces default bufio.Scanner SSE readers with streaming readers) BifrostContextKeySessionID BifrostContextKey = "bifrost-session-id" // string session ID for the request (session stickiness) BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) + BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded only to the MCP while tool execution if they are in the allowlist of the MCP client) BifrostContextKeyMCPLogID BifrostContextKey = "bifrost-mcp-log-id" // string (unique UUID for each MCP tool log entry - set per goroutine by agent executor - DO NOT SET THIS MANUALLY) ) @@ -271,6 +294,27 @@ type RoutingEngineLogEntry struct { Timestamp int64 // Unix milliseconds } +// PluginLogEntry represents a structured log entry emitted by a plugin via ctx.Log(). +type PluginLogEntry struct { + PluginName string `json:"plugin_name"` + Level LogLevel `json:"level"` + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` // Unix milliseconds +} + +// GroupPluginLogsByName groups a flat slice of plugin log entries by plugin name. +// Returns nil if the input is empty. +func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry { + if len(logs) == 0 { + return nil + } + grouped := make(map[string][]PluginLogEntry, min(len(logs), 4)) + for _, entry := range logs { + grouped[entry.PluginName] = append(grouped[entry.PluginName], entry) + } + return grouped +} + // NOTE: for custom plugin implementation dealing with streaming short circuit, // make sure to mark BifrostContextKeyStreamEndIndicator as true at the end of the stream. @@ -769,6 +813,213 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &BifrostResponseExtraFields{} } +func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if r == nil { + return + } + resolvedModel := resolvedModelUsed + if resolvedModel == "" { + resolvedModel = originalModelRequested + } + switch { + case r.ListModelsResponse != nil: + r.ListModelsResponse.ExtraFields.RequestType = requestType + r.ListModelsResponse.ExtraFields.Provider = provider + r.ListModelsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ListModelsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TextCompletionResponse != nil: + r.TextCompletionResponse.ExtraFields.RequestType = requestType + r.TextCompletionResponse.ExtraFields.Provider = provider + r.TextCompletionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TextCompletionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ChatResponse != nil: + r.ChatResponse.ExtraFields.RequestType = requestType + r.ChatResponse.ExtraFields.Provider = provider + r.ChatResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ChatResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesResponse != nil: + r.ResponsesResponse.ExtraFields.RequestType = requestType + r.ResponsesResponse.ExtraFields.Provider = provider + r.ResponsesResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesStreamResponse != nil: + r.ResponsesStreamResponse.ExtraFields.RequestType = requestType + r.ResponsesStreamResponse.ExtraFields.Provider = provider + r.ResponsesStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.CountTokensResponse != nil: + r.CountTokensResponse.ExtraFields.RequestType = requestType + r.CountTokensResponse.ExtraFields.Provider = provider + r.CountTokensResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.CountTokensResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.EmbeddingResponse != nil: + r.EmbeddingResponse.ExtraFields.RequestType = requestType + r.EmbeddingResponse.ExtraFields.Provider = provider + r.EmbeddingResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.EmbeddingResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.RerankResponse != nil: + r.RerankResponse.ExtraFields.RequestType = requestType + r.RerankResponse.ExtraFields.Provider = provider + r.RerankResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.RerankResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechResponse != nil: + r.SpeechResponse.ExtraFields.RequestType = requestType + r.SpeechResponse.ExtraFields.Provider = provider + r.SpeechResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechStreamResponse != nil: + r.SpeechStreamResponse.ExtraFields.RequestType = requestType + r.SpeechStreamResponse.ExtraFields.Provider = provider + r.SpeechStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionResponse != nil: + r.TranscriptionResponse.ExtraFields.RequestType = requestType + r.TranscriptionResponse.ExtraFields.Provider = provider + r.TranscriptionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionStreamResponse != nil: + r.TranscriptionStreamResponse.ExtraFields.RequestType = requestType + r.TranscriptionStreamResponse.ExtraFields.Provider = provider + r.TranscriptionStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationResponse != nil: + r.ImageGenerationResponse.ExtraFields.RequestType = requestType + r.ImageGenerationResponse.ExtraFields.Provider = provider + r.ImageGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationStreamResponse != nil: + r.ImageGenerationStreamResponse.ExtraFields.RequestType = requestType + r.ImageGenerationStreamResponse.ExtraFields.Provider = provider + r.ImageGenerationStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoGenerationResponse != nil: + r.VideoGenerationResponse.ExtraFields.RequestType = requestType + r.VideoGenerationResponse.ExtraFields.Provider = provider + r.VideoGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDownloadResponse != nil: + r.VideoDownloadResponse.ExtraFields.RequestType = requestType + r.VideoDownloadResponse.ExtraFields.Provider = provider + r.VideoDownloadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDownloadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoListResponse != nil: + r.VideoListResponse.ExtraFields.RequestType = requestType + r.VideoListResponse.ExtraFields.Provider = provider + r.VideoListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDeleteResponse != nil: + r.VideoDeleteResponse.ExtraFields.RequestType = requestType + r.VideoDeleteResponse.ExtraFields.Provider = provider + r.VideoDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileUploadResponse != nil: + r.FileUploadResponse.ExtraFields.RequestType = requestType + r.FileUploadResponse.ExtraFields.Provider = provider + r.FileUploadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileUploadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileListResponse != nil: + r.FileListResponse.ExtraFields.RequestType = requestType + r.FileListResponse.ExtraFields.Provider = provider + r.FileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileRetrieveResponse != nil: + r.FileRetrieveResponse.ExtraFields.RequestType = requestType + r.FileRetrieveResponse.ExtraFields.Provider = provider + r.FileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileDeleteResponse != nil: + r.FileDeleteResponse.ExtraFields.RequestType = requestType + r.FileDeleteResponse.ExtraFields.Provider = provider + r.FileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileContentResponse != nil: + r.FileContentResponse.ExtraFields.RequestType = requestType + r.FileContentResponse.ExtraFields.Provider = provider + r.FileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCreateResponse != nil: + r.BatchCreateResponse.ExtraFields.RequestType = requestType + r.BatchCreateResponse.ExtraFields.Provider = provider + r.BatchCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchListResponse != nil: + r.BatchListResponse.ExtraFields.RequestType = requestType + r.BatchListResponse.ExtraFields.Provider = provider + r.BatchListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchRetrieveResponse != nil: + r.BatchRetrieveResponse.ExtraFields.RequestType = requestType + r.BatchRetrieveResponse.ExtraFields.Provider = provider + r.BatchRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCancelResponse != nil: + r.BatchCancelResponse.ExtraFields.RequestType = requestType + r.BatchCancelResponse.ExtraFields.Provider = provider + r.BatchCancelResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCancelResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchDeleteResponse != nil: + r.BatchDeleteResponse.ExtraFields.RequestType = requestType + r.BatchDeleteResponse.ExtraFields.Provider = provider + r.BatchDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchResultsResponse != nil: + r.BatchResultsResponse.ExtraFields.RequestType = requestType + r.BatchResultsResponse.ExtraFields.Provider = provider + r.BatchResultsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchResultsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerCreateResponse != nil: + r.ContainerCreateResponse.ExtraFields.RequestType = requestType + r.ContainerCreateResponse.ExtraFields.Provider = provider + r.ContainerCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerListResponse != nil: + r.ContainerListResponse.ExtraFields.RequestType = requestType + r.ContainerListResponse.ExtraFields.Provider = provider + r.ContainerListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerRetrieveResponse != nil: + r.ContainerRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerRetrieveResponse.ExtraFields.Provider = provider + r.ContainerRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerDeleteResponse != nil: + r.ContainerDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerDeleteResponse.ExtraFields.Provider = provider + r.ContainerDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileCreateResponse != nil: + r.ContainerFileCreateResponse.ExtraFields.RequestType = requestType + r.ContainerFileCreateResponse.ExtraFields.Provider = provider + r.ContainerFileCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileListResponse != nil: + r.ContainerFileListResponse.ExtraFields.RequestType = requestType + r.ContainerFileListResponse.ExtraFields.Provider = provider + r.ContainerFileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileRetrieveResponse != nil: + r.ContainerFileRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerFileRetrieveResponse.ExtraFields.Provider = provider + r.ContainerFileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileContentResponse != nil: + r.ContainerFileContentResponse.ExtraFields.RequestType = requestType + r.ContainerFileContentResponse.ExtraFields.Provider = provider + r.ContainerFileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileDeleteResponse != nil: + r.ContainerFileDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerFileDeleteResponse.ExtraFields.Provider = provider + r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.PassthroughResponse != nil: + r.PassthroughResponse.ExtraFields.RequestType = requestType + r.PassthroughResponse.ExtraFields.Provider = provider + r.PassthroughResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.PassthroughResponse.ExtraFields.ResolvedModelUsed = resolvedModel + } +} + // BifrostMCPResponse is the response struct for all MCP responses. // only ONE of the following fields should be set: // - ChatMessage @@ -783,10 +1034,10 @@ type BifrostMCPResponse struct { type BifrostResponseExtraFields struct { RequestType RequestType `json:"request_type"` Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - ModelDeployment string `json:"model_deployment,omitempty"` // only present for providers which use model deployments (e.g. Azure, Bedrock) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses RawRequest interface{} `json:"raw_request,omitempty"` RawResponse interface{} `json:"raw_response,omitempty"` CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` @@ -880,6 +1131,20 @@ type BifrostError struct { ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } +func (e *BifrostError) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if e == nil { + return + } + e.ExtraFields.RequestType = requestType + e.ExtraFields.Provider = provider + e.ExtraFields.OriginalModelRequested = originalModelRequested + if resolvedModelUsed != "" { + e.ExtraFields.ResolvedModelUsed = resolvedModelUsed + } else { + e.ExtraFields.ResolvedModelUsed = originalModelRequested + } +} + // StreamControl represents stream control options. type StreamControl struct { LogError *bool `json:"log_error,omitempty"` // Optional: Controls logging of error @@ -953,11 +1218,13 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest any `json:"raw_request,omitempty"` + RawResponse any `json:"raw_response,omitempty"` + LiteLLMCompat bool `json:"litellm_compat,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication } diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 221390ef17..b864349213 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -63,7 +63,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -96,7 +97,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -132,7 +134,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -149,13 +152,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } diff --git a/core/schemas/context.go b/core/schemas/context.go index 1ff4663eae..6a701eb5eb 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -26,6 +26,28 @@ var reservedKeys = []any{ BifrostContextKeyDeferTraceCompletion, } +// pluginLogStore holds plugin log entries accumulated during request processing. +// It is shared between the root BifrostContext and all scoped contexts derived from it. +// Uses a flat slice (not map) to minimize heap allocations. +type pluginLogStore struct { + mu sync.Mutex + logs []PluginLogEntry +} + +// pluginLogStorePool pools pluginLogStore instances to reduce per-request allocations. +var pluginLogStorePool = sync.Pool{ + New: func() any { + return &pluginLogStore{logs: make([]PluginLogEntry, 0, 8)} + }, +} + +// pluginScopePool pools BifrostContext instances used as scoped plugin contexts. +var pluginScopePool = sync.Pool{ + New: func() any { + return &BifrostContext{} + }, +} + // BifrostContext is a custom context.Context implementation that tracks user-set values. // It supports deadlines, can be derived from other contexts, and provides layered // value inheritance when derived from another BifrostContext. @@ -40,6 +62,11 @@ type BifrostContext struct { userValues map[any]any valuesMu sync.RWMutex blockRestrictedWrites atomic.Bool + + // Plugin scoping fields + pluginScope *string // Non-nil when this is a scoped plugin context + pluginLogs atomic.Pointer[pluginLogStore] // Shared log store; lazily initialized on root, shared by scoped contexts + valueDelegate *BifrostContext // For scoped contexts: delegate Value/SetValue to this root context } // NewBifrostContext creates a new BifrostContext with the given parent context and deadline. @@ -166,8 +193,12 @@ func (bc *BifrostContext) cancel(err error) { } // Deadline returns the deadline for this context. +// For scoped contexts, delegates to the root context. // If both this context and the parent have deadlines, the earlier one is returned. func (bc *BifrostContext) Deadline() (time.Time, bool) { + if bc.valueDelegate != nil { + return bc.valueDelegate.Deadline() + } parentDeadline, parentHasDeadline := bc.parent.Deadline() if !bc.hasDeadline && !parentHasDeadline { @@ -195,16 +226,24 @@ func (bc *BifrostContext) Done() <-chan struct{} { } // Err returns the error explaining why the context was cancelled. +// For scoped contexts, delegates to the root context. // Returns nil if the context has not been cancelled. func (bc *BifrostContext) Err() error { + if bc.valueDelegate != nil { + return bc.valueDelegate.Err() + } bc.errMu.RLock() defer bc.errMu.RUnlock() return bc.err } // Value returns the value associated with the key. -// It first checks the internal userValues map, then delegates to the parent context. +// For scoped contexts, delegates to the root context via valueDelegate. +// Otherwise checks the internal userValues map, then delegates to the parent context. func (bc *BifrostContext) Value(key any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.Value(key) + } bc.valuesMu.RLock() if val, ok := bc.userValues[key]; ok { bc.valuesMu.RUnlock() @@ -212,12 +251,21 @@ func (bc *BifrostContext) Value(key any) any { } bc.valuesMu.RUnlock() + if bc.parent == nil { + return nil + } + return bc.parent.Value(key) } // SetValue sets a value in the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. // This is thread-safe and can be called concurrently. func (bc *BifrostContext) SetValue(key, value any) { + if bc.valueDelegate != nil { + bc.valueDelegate.SetValue(key, value) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -232,7 +280,12 @@ func (bc *BifrostContext) SetValue(key, value any) { } // ClearValue clears a value from the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. func (bc *BifrostContext) ClearValue(key any) { + if bc.valueDelegate != nil { + bc.valueDelegate.ClearValue(key) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -245,8 +298,12 @@ func (bc *BifrostContext) ClearValue(key any) { } } -// GetAndSetValue gets a value from the internal userValues map and sets it +// GetAndSetValue gets a value from the internal userValues map and sets it. +// For scoped contexts, delegates to the root context via valueDelegate. func (bc *BifrostContext) GetAndSetValue(key any, value any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.GetAndSetValue(key, value) + } bc.valuesMu.Lock() defer bc.valuesMu.Unlock() // Check if the key is a reserved key @@ -340,3 +397,104 @@ func AppendToContextList[T any](ctx *BifrostContext, key BifrostContextKey, valu } ctx.SetValue(key, append(existingValues, value)) } + +// WithPluginScope returns a lightweight scoped BifrostContext from the pool. +// The scoped context shares the root's pluginLogs store and delegates all +// Value/SetValue operations to the root context. +// Call ReleasePluginScope() when done to return the scoped context to the pool. +func (bc *BifrostContext) WithPluginScope(name *string) *BifrostContext { + // Lazily initialize the plugin log store on the root context (CAS to avoid race) + if bc.pluginLogs.Load() == nil { + newStore := pluginLogStorePool.Get().(*pluginLogStore) + if !bc.pluginLogs.CompareAndSwap(nil, newStore) { + // Another goroutine initialized first β€” return unused store to pool + pluginLogStorePool.Put(newStore) + } + } + + scoped := pluginScopePool.Get().(*BifrostContext) + scoped.parent = bc.parent + scoped.done = bc.done + scoped.pluginScope = name + scoped.pluginLogs.Store(bc.pluginLogs.Load()) + scoped.valueDelegate = bc + return scoped +} + +// ReleasePluginScope returns a scoped context to the pool. +// Safe no-op if called on a non-scoped context. +// Do not use the scoped context after calling this method. +func (bc *BifrostContext) ReleasePluginScope() { + if bc.valueDelegate == nil { + return // not a scoped context + } + bc.parent = nil + bc.done = nil + bc.pluginScope = nil + bc.pluginLogs.Store(nil) + bc.valueDelegate = nil + pluginScopePool.Put(bc) +} + +// Log appends a structured log entry for the current plugin scope. +// No-op if the context is not scoped to a plugin or has no log store. +func (bc *BifrostContext) Log(level LogLevel, msg string) { + store := bc.pluginLogs.Load() + if bc.pluginScope == nil || store == nil { + return + } + store.mu.Lock() + store.logs = append(store.logs, PluginLogEntry{ + PluginName: *bc.pluginScope, + Level: level, + Message: msg, + Timestamp: time.Now().UnixMilli(), + }) + store.mu.Unlock() +} + +// GetPluginLogs returns a deep copy of all accumulated plugin log entries. +// Thread-safe. Returns nil if no logs have been recorded. +func (bc *BifrostContext) GetPluginLogs() []PluginLogEntry { + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + store.mu.Lock() + defer store.mu.Unlock() + if len(store.logs) == 0 { + return nil + } + copied := make([]PluginLogEntry, len(store.logs)) + copy(copied, store.logs) + return copied +} + +// DrainPluginLogs transfers ownership of the plugin log slice to the caller. +// The internal log store is returned to the pool after draining. +// Returns nil if no logs have been recorded. +// This should be called once on the root context after all plugin hooks have completed. +func (bc *BifrostContext) DrainPluginLogs() []PluginLogEntry { + if bc.valueDelegate != nil { + return nil // scoped contexts must not drain the shared log store + } + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + bc.pluginLogs.Store(nil) + + store.mu.Lock() + logs := store.logs + // Reset with fresh pre-allocated slice before returning to pool + store.logs = make([]PluginLogEntry, 0, 8) + store.mu.Unlock() + + // Return the store to the pool for reuse + pluginLogStorePool.Put(store) + + if len(logs) == 0 { + return nil + } + return logs +} diff --git a/core/schemas/context_test.go b/core/schemas/context_test.go index 75e52e2061..108da2ced0 100644 --- a/core/schemas/context_test.go +++ b/core/schemas/context_test.go @@ -207,3 +207,125 @@ func TestNewBifrostContext_NilParent(t *testing.T) { t.Errorf("Cancelled context should have Canceled error, got %v", ctx.Err()) } } + +// Plugin logging tests + +func TestPluginLog_NoScopeIsNoop(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.Log(LogLevelInfo, "should be ignored") + logs := ctx.GetPluginLogs() + if logs != nil { + t.Errorf("expected nil logs without plugin scope, got %v", logs) + } +} + +func TestPluginLog_SinglePlugin(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "test-plugin" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "hello") + scoped.Log(LogLevelError, "oops") + scoped.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "test-plugin" || logs[0].Level != LogLevelInfo || logs[0].Message != "hello" { + t.Errorf("unexpected first log: %+v", logs[0]) + } + if logs[1].Level != LogLevelError || logs[1].Message != "oops" { + t.Errorf("unexpected second log: %+v", logs[1]) + } +} + +func TestPluginLog_MultiplePlugins(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + + name1 := "plugin-a" + s1 := ctx.WithPluginScope(&name1) + s1.Log(LogLevelDebug, "a-msg") + s1.ReleasePluginScope() + + name2 := "plugin-b" + s2 := ctx.WithPluginScope(&name2) + s2.Log(LogLevelWarn, "b-msg") + s2.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "plugin-a" { + t.Errorf("expected plugin-a, got %s", logs[0].PluginName) + } + if logs[1].PluginName != "plugin-b" { + t.Errorf("expected plugin-b, got %s", logs[1].PluginName) + } +} + +func TestPluginLog_DrainTransfersOwnership(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "drain-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "msg1") + scoped.ReleasePluginScope() + + drained := ctx.DrainPluginLogs() + if len(drained) != 1 { + t.Fatalf("expected 1 drained log, got %d", len(drained)) + } + + // After drain, GetPluginLogs should return nil + after := ctx.GetPluginLogs() + if after != nil { + t.Errorf("expected nil after drain, got %v", after) + } + + // Second drain should return nil + second := ctx.DrainPluginLogs() + if second != nil { + t.Errorf("expected nil on second drain, got %v", second) + } +} + +func TestPluginLog_ScopedContextValueDelegation(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.SetValue(BifrostContextKeyTraceID, "trace-123") + + name := "delegate-test" + scoped := ctx.WithPluginScope(&name) + + // Scoped should read from root + val := scoped.Value(BifrostContextKeyTraceID) + if val != "trace-123" { + t.Errorf("expected trace-123, got %v", val) + } + + // Scoped should write to root + type testContextKey string + const customKey testContextKey = "custom-key" + scoped.SetValue(customKey, "custom-val") + if ctx.Value(customKey) != "custom-val" { + t.Errorf("SetValue on scoped did not delegate to root") + } + + scoped.ReleasePluginScope() +} + +func TestPluginLog_PoolReuse(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + + // Create and release multiple scoped contexts to exercise the pool + for i := 0; i < 100; i++ { + name := "pool-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "pooled") + scoped.ReleasePluginScope() + } + + logs := ctx.DrainPluginLogs() + if len(logs) != 100 { + t.Errorf("expected 100 logs from pool reuse, got %d", len(logs)) + } +} diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 1d8890dd1f..9ca2fb2cdf 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -116,14 +116,16 @@ type EmbeddingParameters struct { type EmbeddingData struct { Index int `json:"index"` Object string `json:"object"` // "embedding" - Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64 or [][]float64 + Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32 } type EmbeddingStruct struct { // Embedding responses preserve provider precision in normalized API output. - EmbeddingStr *string - EmbeddingArray []float64 - Embedding2DArray [][]float64 + EmbeddingStr *string + EmbeddingArray []float64 + Embedding2DArray [][]float64 + EmbeddingInt8Array []int8 // for int8 / binary formats + EmbeddingInt32Array []int32 // for uint8 / ubinary formats } func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { @@ -136,6 +138,12 @@ func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { if be.Embedding2DArray != nil { return MarshalSorted(be.Embedding2DArray) } + if be.EmbeddingInt8Array != nil { + return Marshal(be.EmbeddingInt8Array) + } + if be.EmbeddingInt32Array != nil { + return Marshal(be.EmbeddingInt32Array) + } return nil, fmt.Errorf("no embedding found") } @@ -161,5 +169,19 @@ func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { return nil } - return fmt.Errorf("embedding field is neither a string nor an array of float64 nor a 2D array of float64") + // Try to unmarshal as a direct array of int8 + var int8Content []int8 + if err := Unmarshal(data, &int8Content); err == nil { + be.EmbeddingInt8Array = int8Content + return nil + } + + // Try to unmarshal as a direct array of int32 + var int32Content []int32 + if err := Unmarshal(data, &int32Content); err == nil { + be.EmbeddingInt32Array = int32Content + return nil + } + + return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32") } diff --git a/core/schemas/envvar.go b/core/schemas/envvar.go index 6c5f996f08..c8fe249699 100644 --- a/core/schemas/envvar.go +++ b/core/schemas/envvar.go @@ -117,6 +117,9 @@ func (e *EnvVar) Equals(other *EnvVar) bool { // Redacted returns a new SecretKey with the value redacted. func (e *EnvVar) Redacted() *EnvVar { + if e == nil { + return nil + } if e.Val == "" { return &EnvVar{ Val: "", @@ -144,6 +147,34 @@ func (e *EnvVar) Redacted() *EnvVar { } } +// MarshalJSON serializes the EnvVar to JSON. +// SECURITY: When the value was sourced from an environment variable, the resolved +// value is automatically redacted before being serialized. This ensures that secrets +// injected via env vars are never leaked through any JSON API response, regardless +// of whether the surrounding code remembered to call Redacted() explicitly. +// +// Plain (non-env) values are still emitted as-is β€” callers that want to mask those +// must continue using Redacted() at the field level (this matches the existing +// per-provider redaction logic). +// +// This does NOT affect: +// - GORM persistence (uses the Value() driver method, not JSON) +// - Encryption (operates on the Val field directly) +// - Internal LLM request paths (use GetValue() directly) +func (e EnvVar) MarshalJSON() ([]byte, error) { + type envVarAlias EnvVar + out := envVarAlias(e) + if e.FromEnv { + // Redact the resolved value but keep the env var reference and from_env flag + // so the UI still knows which env var backs this field. + redacted := e.Redacted() + if redacted != nil { + out = envVarAlias(*redacted) + } + } + return sonic.Marshal(out) +} + // UnmarshalJSON unmarshals the value from JSON. func (e *EnvVar) UnmarshalJSON(data []byte) error { // This is always going to be value @@ -259,6 +290,17 @@ func (e *EnvVar) IsFromEnv() bool { return e.FromEnv } +// IsSet returns true if the EnvVar has a resolved value or an environment variable reference. +// This should be used instead of GetValue() != "" when checking whether a field was configured, +// because env var references may have an empty Val before resolution (e.g., when the env var +// is not available in the current environment). +func (e *EnvVar) IsSet() bool { + if e == nil { + return false + } + return e.Val != "" || e.EnvVar != "" +} + // GetValue returns the value. func (e *EnvVar) GetValue() string { if e == nil { @@ -298,3 +340,14 @@ func (e *EnvVar) CoerceBool(defaultValue bool) bool { } return val } + +// IsDefined returns true if the EnvVar has a source (static value or env key) +func (e *EnvVar) IsDefined() bool { + if e == nil { + return false + } + if e.IsFromEnv() { + return e.EnvVar != "" + } + return e.Val != "" +} diff --git a/core/schemas/envvar_test.go b/core/schemas/envvar_test.go index 5a451b5058..9b22673ae8 100644 --- a/core/schemas/envvar_test.go +++ b/core/schemas/envvar_test.go @@ -419,3 +419,191 @@ func TestEnvVar_IsRedacted(t *testing.T) { }) } } + +// TestEnvVar_IsSet verifies the semantic difference between GetValue() != "" and IsSet(). +// IsSet() must return true when the EnvVar references an env var (regardless of whether +// that env var has been resolved to a non-empty Val). This is the property that the +// BeforeSave hooks rely on so env var references survive persistence. +func TestEnvVar_IsSet(t *testing.T) { + tests := []struct { + name string + input *EnvVar + expected bool + }{ + { + name: "nil envvar", + input: nil, + expected: false, + }, + { + name: "completely empty", + input: &EnvVar{}, + expected: false, + }, + { + name: "only Val set (plain value)", + input: &EnvVar{Val: "abc"}, + expected: true, + }, + { + name: "only EnvVar reference set (env not resolved on this server)", + input: &EnvVar{EnvVar: "env.MISSING", FromEnv: true}, + expected: true, + }, + { + name: "Val and EnvVar both set (env was resolved)", + input: &EnvVar{Val: "resolved-secret", EnvVar: "env.X", FromEnv: true}, + expected: true, + }, + { + name: "FromEnv true but no reference and no value", + input: &EnvVar{FromEnv: true}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.input.IsSet(); got != tt.expected { + t.Errorf("IsSet() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues verifies that any EnvVar marshaled +// to JSON with FromEnv=true is automatically masked, regardless of whether the +// surrounding code remembered to call Redacted() explicitly. This is the defense-in-depth +// guarantee that prevents env-resolved secrets from leaking through unredacted fields. +func TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues(t *testing.T) { + tests := []struct { + name string + input EnvVar + wantValue string + wantEnvVar string + wantFromEnv bool + }{ + { + name: "env-backed long secret is redacted", + input: EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.OPENAI_API_KEY", FromEnv: true}, + wantValue: "sk-1************************mnop", + wantEnvVar: "env.OPENAI_API_KEY", + wantFromEnv: true, + }, + { + name: "env-backed short secret is fully masked", + input: EnvVar{Val: "12345678", EnvVar: "env.SHORT", FromEnv: true}, + wantValue: "********", + wantEnvVar: "env.SHORT", + wantFromEnv: true, + }, + { + name: "env-backed unresolved on this server keeps empty value", + input: EnvVar{Val: "", EnvVar: "env.MISSING", FromEnv: true}, + wantValue: "", + wantEnvVar: "env.MISSING", + wantFromEnv: true, + }, + { + name: "plain value (not from env) is NOT redacted", + input: EnvVar{Val: "2024-10-21", EnvVar: "", FromEnv: false}, + wantValue: "2024-10-21", + wantEnvVar: "", + wantFromEnv: false, + }, + { + name: "empty plain value passes through", + input: EnvVar{Val: "", EnvVar: "", FromEnv: false}, + wantValue: "", + wantEnvVar: "", + wantFromEnv: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.input) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got struct { + Value string `json:"value"` + EnvVar string `json:"env_var"` + FromEnv bool `json:"from_env"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal of marshaled output failed: %v", err) + } + if got.Value != tt.wantValue { + t.Errorf("value: got %q, want %q", got.Value, tt.wantValue) + } + if got.EnvVar != tt.wantEnvVar { + t.Errorf("env_var: got %q, want %q", got.EnvVar, tt.wantEnvVar) + } + if got.FromEnv != tt.wantFromEnv { + t.Errorf("from_env: got %v, want %v", got.FromEnv, tt.wantFromEnv) + } + }) + } +} + +// TestEnvVar_MarshalJSON_DoesNotMutateOriginal ensures the auto-redaction in MarshalJSON +// does not mutate the receiver. The inference path calls GetValue() to build the actual +// HTTP request to the LLM provider, so the original Val must remain intact. +func TestEnvVar_MarshalJSON_DoesNotMutateOriginal(t *testing.T) { + original := EnvVar{Val: "real-secret-value", EnvVar: "env.SECRET", FromEnv: true} + if _, err := json.Marshal(original); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if original.Val != "real-secret-value" { + t.Errorf("MarshalJSON mutated Val: got %q, want %q", original.Val, "real-secret-value") + } + if original.GetValue() != "real-secret-value" { + t.Errorf("GetValue() returns mutated value: got %q", original.GetValue()) + } +} + +// TestEnvVar_MarshalJSON_RoundTripIsRedacted verifies that a marshaled-then-unmarshaled +// env-backed EnvVar is recognized as redacted. The merge logic in provider_keys.go relies +// on this so it can detect "the UI sent back the same redacted value, don't overwrite". +func TestEnvVar_MarshalJSON_RoundTripIsRedacted(t *testing.T) { + original := EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.KEY", FromEnv: true} + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var roundTripped EnvVar + if err := json.Unmarshal(data, &roundTripped); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !roundTripped.IsRedacted() { + t.Errorf("Round-tripped env-backed value should be IsRedacted, got Val=%q", roundTripped.Val) + } + if roundTripped.EnvVar != "env.KEY" { + t.Errorf("env_var reference lost in round-trip: got %q, want %q", roundTripped.EnvVar, "env.KEY") + } +} + +// TestEnvVar_MarshalJSON_DoesNotAffectGetValue is a critical safety net: marshaling an +// EnvVar to JSON must NOT change what GetValue() returns. The inference path uses +// GetValue() to build outgoing LLM requests; if marshaling were to mutate the value, +// every request after a UI fetch would silently start using the redacted mask as the +// API key. +func TestEnvVar_MarshalJSON_DoesNotAffectGetValue(t *testing.T) { + os.Setenv("MY_REAL_API_KEY", "sk-real-secret-1234567890abcdef") + defer os.Unsetenv("MY_REAL_API_KEY") + + ev := NewEnvVar("env.MY_REAL_API_KEY") + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Fatalf("setup: GetValue() = %q, want resolved env value", ev.GetValue()) + } + + // Marshaling would redact in the JSON output, but must not touch the in-memory Val. + if _, err := json.Marshal(ev); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Errorf("GetValue() returns mutated value after MarshalJSON: got %q", ev.GetValue()) + } +} diff --git a/core/schemas/images.go b/core/schemas/images.go index d16df42a10..8cdff83372 100644 --- a/core/schemas/images.go +++ b/core/schemas/images.go @@ -71,6 +71,11 @@ type BifrostImageGenerationResponse struct { func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) { numInputImages, size, quality := getNumInputImagesSizeAndQualityFromRequest(req) + // Backfill Model if not returned by the provider + if r.Model == "" { + r.Model = getModelFromRequest(req) + } + // Backfill NumInputImages if numInputImages > 0 { if r.Usage == nil { @@ -96,6 +101,22 @@ func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) { } } +// getModelFromRequest extracts the model from any image-related request. +func getModelFromRequest(req *BifrostRequest) string { + if req == nil { + return "" + } + switch { + case req.ImageGenerationRequest != nil: + return req.ImageGenerationRequest.Model + case req.ImageEditRequest != nil: + return req.ImageEditRequest.Model + case req.ImageVariationRequest != nil: + return req.ImageVariationRequest.Model + } + return "" +} + // getNumInputImagesSizeAndQualityFromRequest extracts request params for cost calculation. // Quality is only returned when it is one of low, medium, high, auto. func getNumInputImagesSizeAndQualityFromRequest(req *BifrostRequest) (numInputImages int, size string, quality string) { @@ -151,10 +172,12 @@ func normalizeImageQuality(q string) string { } type ImageGenerationResponseParameters struct { - Background string `json:"background,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` + Background string `json:"background,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + FinishReasons []*string `json:"finish_reasons,omitempty"` + Seeds []int `json:"seeds,omitempty"` } type ImageData struct { @@ -254,7 +277,7 @@ type ImageInput struct { } type ImageEditParameters struct { - Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", + Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", "remove_background", "erase_object", "recolor", "search_replace", "control_sketch", "control_structure", "style_guide", "style_transfer", "upscale_fast", "upscale_creative", "upscale_conservative" Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto" InputFidelity *string `json:"input_fidelity,omitempty"` // "low", "high" Mask []byte `json:"mask,omitempty"` diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 72b43cbc8e..af87cdc743 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -20,8 +20,25 @@ var ( ErrOAuth2TokenExpired = errors.New("oauth2 token expired") ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") + ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") + ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server") + ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired") ) +// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires +// the user to authenticate before tool execution can proceed. +type MCPUserOAuthRequiredError struct { + MCPClientID string `json:"mcp_client_id"` + MCPClientName string `json:"mcp_client_name"` + AuthorizeURL string `json:"authorize_url"` + SessionID string `json:"session_id"` + Message string `json:"message"` +} + +func (e *MCPUserOAuthRequiredError) Error() string { + return e.Message +} + // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { @@ -46,9 +63,10 @@ type MCPConfig struct { } type MCPToolManagerConfig struct { - ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` - MaxAgentDepth int `json:"max_agent_depth"` - CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default } const ( @@ -68,41 +86,48 @@ const ( type MCPAuthType string const ( - MCPAuthTypeNone MCPAuthType = "none" // No authentication - MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) - MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication + MCPAuthTypeNone MCPAuthType = "none" // No authentication + MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) + MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication (server-level, admin authenticates once) + MCPAuthTypePerUserOauth MCPAuthType = "per_user_oauth" // Per-user OAuth 2.0 authentication (each user authenticates individually) ) // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"client_id"` // Client ID - Name string `json:"name"` // Client name - IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) - ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) - StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) - AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) - OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) - State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) - Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) - InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) - ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. + ID string `json:"client_id"` // Client ID + Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) + ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) + OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) + State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) + Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) + AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) + ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools - ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. // ToolsToAutoExecute semantics: // - ["*"] => all tools are auto-executed // - [] => no tools are auto-executed (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. - ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) - ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys + + // Discovered tools for per-user OAuth clients (persisted so they survive restart) + DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name + DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names } // NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any. @@ -147,6 +172,9 @@ func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2 for key, value := range c.Headers { headers[key] = value.GetValue() } + case MCPAuthTypePerUserOauth: + // Per-user OAuth: headers are injected per-call in executeToolInternal, not at connection level + return headers, nil case MCPAuthTypeNone: // No headers to add default: @@ -179,9 +207,10 @@ type MCPStdioConfig struct { type MCPConnectionState string const ( - MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use - MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected - MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use + MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected + MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStatePendingTools MCPConnectionState = "pending_tools" // Connected but tools not yet populated ) // MCPClientState represents a connected MCP client with its configuration and tools. diff --git a/core/schemas/models.go b/core/schemas/models.go index 5a0e8588c1..32b82bb104 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -138,7 +138,7 @@ type Model struct { ID string `json:"id"` CanonicalSlug *string `json:"canonical_slug,omitempty"` Name *string `json:"name,omitempty"` - Deployment *string `json:"deployment,omitempty"` // Name of the actual deployment + Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. Azure deployment name, Bedrock ARN) Created *int64 `json:"created,omitempty"` ContextLength *int `json:"context_length,omitempty"` MaxInputTokens *int `json:"max_input_tokens,omitempty"` diff --git a/core/schemas/models_test.go b/core/schemas/models_test.go index 3e60fdda76..b9748952bd 100644 --- a/core/schemas/models_test.go +++ b/core/schemas/models_test.go @@ -94,7 +94,7 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { Error: &ErrorField{Message: "unauthorized"}, ExtraFields: BifrostErrorExtraFields{ Provider: "openai", - ModelRequested: "gpt-4", + OriginalModelRequested: "gpt-4", }, } keyStatus := KeyStatus{ @@ -112,6 +112,6 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { // Error fields other than key_statuses should be preserved dataStr := string(data) assert.Contains(t, dataStr, `"unauthorized"`) - assert.Contains(t, dataStr, `"model_requested":"gpt-4"`) + assert.Contains(t, dataStr, `"original_model_requested":"gpt-4"`) assert.Contains(t, dataStr, `"status_code":401`) } diff --git a/core/schemas/oauth.go b/core/schemas/oauth.go index 1a953c3cc6..7ff3bb9362 100644 --- a/core/schemas/oauth.go +++ b/core/schemas/oauth.go @@ -7,7 +7,7 @@ import ( // OauthProvider interface defines OAuth operations type OAuth2Provider interface { - // GetAccessToken retrieves the access token for a given oauth_config_id + // GetAccessToken retrieves the access token for a given oauth_config_id (server-level OAuth) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) // RefreshAccessToken refreshes the access token for a given oauth_config_id @@ -18,6 +18,31 @@ type OAuth2Provider interface { // RevokeToken revokes the OAuth token RevokeToken(ctx context.Context, oauthConfigID string) error + + // Per-user OAuth methods + + // GetUserAccessToken retrieves the access token for a per-user OAuth session. + // If the token is expired, it automatically attempts a refresh. + GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) + + // GetUserAccessTokenByIdentity retrieves the upstream access token for a user + // identified by virtualKeyID, userID, or sessionToken (fallback), for a specific + // MCP client. Tokens looked up by identity persist across sessions. + GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) + + // InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL. + // Returns (flow initiation details, session ID for polling, error). + InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*OAuth2FlowInitiation, string, error) + + // CompleteUserOAuthFlow handles the OAuth callback for a per-user flow. + // Returns the session token that the user should send on subsequent requests. + CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) + + // RefreshUserAccessToken refreshes a per-user OAuth access token. + RefreshUserAccessToken(ctx context.Context, sessionToken string) error + + // RevokeUserToken revokes a per-user OAuth token and marks the session as revoked. + RevokeUserToken(ctx context.Context, sessionToken string) error } // OauthConfig represents OAuth client configuration diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index f9ea18a4b3..5e0d068718 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -313,9 +313,15 @@ type ObservabilityPlugin interface { // // Implementations should: // - Convert the trace to their backend's format - // - Send the trace to the backend (can be async) + // - Send the trace to the backend (can be async, but see retention note below) // - Handle errors gracefully (log and continue) // // The context passed is a fresh background context, not the request context. + // + // Retention: implementations MUST NOT retain the *Trace pointer after Inject + // returns. The caller releases the trace back to a sync.Pool immediately after + // Inject completes, so any background goroutine that still references it will + // race with pool reuse. If a plugin needs to forward the trace asynchronously, + // it must copy the data it needs before returning. Inject(ctx context.Context, trace *Trace) error } diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 034157e41d..b7c9702ae0 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -8,18 +8,18 @@ import ( ) const ( - DefaultMaxRetries = 0 - DefaultRetryBackoffInitial = 500 * time.Millisecond - DefaultRetryBackoffMax = 5 * time.Second + DefaultMaxRetries = 0 + DefaultRetryBackoffInitial = 500 * time.Millisecond + DefaultRetryBackoffMax = 5 * time.Second DefaultRequestTimeoutInSeconds = 30 - DefaultMaxConnDurationInSeconds = 300 // 5 minutes β€” forces connection recycling to prevent stale connections from NAT/LB silent drops - DefaultBufferSize = 5000 - DefaultConcurrency = 1000 - DefaultStreamBufferSize = 256 - DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk β€” if no data for this many seconds, bifrost closes the connection - DefaultMaxConnsPerHost = 5000 - MaxConnsPerHostUpperBound = 10000 - DefaultMaxIdleConnsPerHost = 40 + DefaultMaxConnDurationInSeconds = 300 // 5 minutes β€” forces connection recycling to prevent stale connections from NAT/LB silent drops + DefaultBufferSize = 5000 + DefaultConcurrency = 1000 + DefaultStreamBufferSize = 256 + DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk β€” if no data for this many seconds, bifrost closes the connection + DefaultMaxConnsPerHost = 5000 + MaxConnsPerHostUpperBound = 10000 + DefaultMaxIdleConnsPerHost = 40 ) // Pre-defined errors for provider operations @@ -52,18 +52,18 @@ const ( // - When marshaling to JSON: a time.Duration is converted to milliseconds type NetworkConfig struct { // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) - BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) - DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) - InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections - CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections + CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s) - MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) - EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) - BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) + EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) + BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") } // UnmarshalJSON customizes JSON unmarshaling for NetworkConfig. @@ -406,67 +406,6 @@ type CustomProviderConfig struct { RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock) } -type PricingOverrideMatchType string - -const ( - PricingOverrideMatchExact PricingOverrideMatchType = "exact" - PricingOverrideMatchWildcard PricingOverrideMatchType = "wildcard" - PricingOverrideMatchRegex PricingOverrideMatchType = "regex" -) - -// ProviderPricingOverride contains a partial pricing patch applied at lookup time. -// Any nil field falls back to the base pricing data. -type ProviderPricingOverride struct { - ModelPattern string `json:"model_pattern"` - MatchType PricingOverrideMatchType `json:"match_type"` - RequestTypes []RequestType `json:"request_types,omitempty"` - - // Basic token pricing - InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"` - OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"` - - // Additional pricing for media - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` - - // Character-based pricing - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - - // Pricing above 128k tokens - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - - // Pricing above 200k tokens - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` - CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` - - // Cache and batch pricing - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` - InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` - - // Image generation pricing - InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` - OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` - OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` - OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` - OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` - OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` - OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` - OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` - OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` - CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` -} - // IsOperationAllowed checks if a specific operation is allowed for this custom provider func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool { if cpc == nil || cpc.AllowedRequests == nil { @@ -482,14 +421,13 @@ type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` // Network configuration ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings // Logger instance, can be provided by the user or bifrost default logger is used if not provided - Logger Logger `json:"-"` - ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) - SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) - CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` - OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` - PricingOverrides []ProviderPricingOverride `json:"pricing_overrides,omitempty"` + Logger Logger `json:"-"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) + SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) + CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` } // OpenAIConfig holds OpenAI-specific provider configuration. diff --git a/core/schemas/realtime.go b/core/schemas/realtime.go index e1e20d7bf4..ec4fd6789d 100644 --- a/core/schemas/realtime.go +++ b/core/schemas/realtime.go @@ -19,33 +19,75 @@ const ( // Server-to-client event types (received from the provider, forwarded to client) const ( - RTEventSessionCreated RealtimeEventType = "session.created" - RTEventSessionUpdated RealtimeEventType = "session.updated" - RTEventConversationCreated RealtimeEventType = "conversation.created" - RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" - RTEventConversationItemDone RealtimeEventType = "conversation.item.done" - RTEventResponseCreated RealtimeEventType = "response.created" - RTEventResponseDone RealtimeEventType = "response.done" - RTEventResponseTextDelta RealtimeEventType = "response.text.delta" - RTEventResponseTextDone RealtimeEventType = "response.text.done" - RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" - RTEventResponseAudioDone RealtimeEventType = "response.audio.done" - RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" - RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" - RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" - RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" - RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" - RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" - RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" - RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" - RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" - RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" - RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" - RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" - RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" - RTEventError RealtimeEventType = "error" + RTEventSessionCreated RealtimeEventType = "session.created" + RTEventSessionUpdated RealtimeEventType = "session.updated" + RTEventConversationCreated RealtimeEventType = "conversation.created" + RTEventConversationItemAdded RealtimeEventType = "conversation.item.added" + RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" + RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved" + RTEventConversationItemDone RealtimeEventType = "conversation.item.done" + RTEventResponseCreated RealtimeEventType = "response.created" + RTEventResponseDone RealtimeEventType = "response.done" + RTEventResponseTextDelta RealtimeEventType = "response.text.delta" + RTEventResponseTextDone RealtimeEventType = "response.text.done" + RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" + RTEventResponseAudioDone RealtimeEventType = "response.audio.done" + RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" + RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" + RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" + RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" + RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" + RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" + RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated" + RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" + RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" + RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" + RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" + RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" + RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" + RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" + RTEventError RealtimeEventType = "error" ) +// IsRealtimeConversationItemEventType reports whether the event carries a +// canonical conversation item payload after provider translation. +func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool { + switch eventType { + case RTEventConversationItemCreate, + RTEventConversationItemAdded, + RTEventConversationItemCreated, + RTEventConversationItemRetrieved, + RTEventConversationItemDone: + return true + default: + return false + } +} + +// IsRealtimeUserInputEvent reports whether the event represents a finalized +// user input item in the canonical Bifrost realtime schema. +func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Role == "user" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeToolOutputEvent reports whether the event represents a finalized +// tool output item in the canonical Bifrost realtime schema. +func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Type == "function_call_output" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized +// input-audio transcript in the canonical Bifrost realtime schema. +func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool { + return event != nil && event.Type == RTEventInputAudioTransCompleted +} + // BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events. // Provider converters translate between this format and the provider-native protocol. type BifrostRealtimeEvent struct { @@ -58,36 +100,42 @@ type BifrostRealtimeEvent struct { Audio []byte `json:"audio,omitempty"` Error *RealtimeError `json:"error,omitempty"` + // ExtraParams preserves provider-specific top-level event fields that are not + // promoted into the common Bifrost schema. + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` + // RawData preserves the original provider event for pass-through or debugging. RawData json.RawMessage `json:"raw_data,omitempty"` } // RealtimeSession describes session configuration for the Realtime connection. type RealtimeSession struct { - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Instructions string `json:"instructions,omitempty"` - Voice string `json:"voice,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` - TurnDetection json.RawMessage `json:"turn_detection,omitempty"` - InputAudioFormat string `json:"input_audio_format,omitempty"` - OutputAudioType string `json:"output_audio_type,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Voice string `json:"voice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` + TurnDetection json.RawMessage `json:"turn_detection,omitempty"` + InputAudioFormat string `json:"input_audio_format,omitempty"` + OutputAudioType string `json:"output_audio_type,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeItem represents a conversation item in the Realtime protocol. type RealtimeItem struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Role string `json:"role,omitempty"` - Status string `json:"status,omitempty"` - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - CallID string `json:"call_id,omitempty"` - Arguments string `json:"arguments,omitempty"` - Output string `json:"output,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Status string `json:"status,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeDelta carries incremental content for streaming events. @@ -103,10 +151,28 @@ type RealtimeDelta struct { // RealtimeError describes an error from the Realtime API. type RealtimeError struct { - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Param string `json:"param,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` +} + +// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint +// shape the client called so providers can preserve versioned behavior. +type RealtimeSessionEndpointType string + +const ( + RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets" + RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions" +) + +// RealtimeSessionRoute describes a provider-registered public route for +// ephemeral-token creation. +type RealtimeSessionRoute struct { + Path string + EndpointType RealtimeSessionEndpointType + DefaultProvider ModelProvider } // RealtimeProvider is an optional interface that providers can implement to @@ -116,6 +182,129 @@ type RealtimeProvider interface { SupportsRealtimeAPI() bool RealtimeWebSocketURL(key Key, model string) string RealtimeHeaders(key Key) map[string]string + // SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange. + SupportsRealtimeWebRTC() bool + // ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange. + // The provider owns the HTTP specifics (URL, headers, body format). + // session may be nil if the signaling format doesn't include session config. + ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error) ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error) + // ShouldStartRealtimeTurn reports whether the canonical client-side event + // should start pre-hooks. Providers without an explicit turn-start signal + // return false and rely on finalize-time fallback hooks. + ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool + // RealtimeTurnFinalEvent returns the canonical provider event that completes + // a turn and should trigger post-hooks. + RealtimeTurnFinalEvent() RealtimeEventType + RealtimeWebRTCDataChannelLabel() string + RealtimeWebSocketSubprotocol() string + ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool + ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool +} + +// RealtimeLegacyWebRTCProvider is an optional interface for providers that +// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime). +// Only checked for legacy integration routes via type assertion. +// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP +// but targets the provider's legacy/beta endpoint. +type RealtimeLegacyWebRTCProvider interface { + ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError) +} + +// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from +// their native wire payloads without coupling handlers to a specific protocol. +type RealtimeUsageExtractor interface { + ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage + ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage +} + +// RealtimeSessionProvider is an optional interface for providers that can mint +// short-lived client secrets for browser/client-side Realtime connections. +// Checked via type assertion: provider.(RealtimeSessionProvider). +type RealtimeSessionProvider interface { + CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError) +} + +// ParseRealtimeEvent decodes a client/provider realtime event while preserving +// unknown top-level fields in ExtraParams for provider-specific round-tripping. +func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) { + type realtimeEventAlias struct { + Type RealtimeEventType `json:"type"` + EventID string `json:"event_id,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Delta *RealtimeDelta `json:"delta,omitempty"` + Audio []byte `json:"audio,omitempty"` + Error *RealtimeError `json:"error,omitempty"` + } + + var alias realtimeEventAlias + if err := Unmarshal(raw, &alias); err != nil { + return nil, err + } + + event := &BifrostRealtimeEvent{ + Type: alias.Type, + EventID: alias.EventID, + Session: alias.Session, + Item: alias.Item, + Delta: alias.Delta, + Audio: alias.Audio, + Error: alias.Error, + } + + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, err + } + savedSession := root["session"] + savedItem := root["item"] + savedError := root["error"] + for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} { + delete(root, key) + } + if len(root) > 0 { + event.ExtraParams = root + } + if event.Session != nil { + var sessionRoot map[string]json.RawMessage + if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil { + for _, key := range []string{ + "id", "model", "modalities", "instructions", "voice", "temperature", + "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools", + } { + delete(sessionRoot, key) + } + if len(sessionRoot) > 0 { + event.Session.ExtraParams = sessionRoot + } + } + } + if event.Item != nil { + var itemRoot map[string]json.RawMessage + if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil { + for _, key := range []string{ + "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output", + } { + delete(itemRoot, key) + } + if len(itemRoot) > 0 { + event.Item.ExtraParams = itemRoot + } + } + } + if event.Error != nil { + var errorRoot map[string]json.RawMessage + if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil { + for _, key := range []string{"type", "code", "message", "param"} { + delete(errorRoot, key) + } + if len(errorRoot) > 0 { + event.Error.ExtraParams = errorRoot + } + } + } + + return event, nil } diff --git a/core/schemas/realtime_client_secrets.go b/core/schemas/realtime_client_secrets.go new file mode 100644 index 0000000000..ae97b573a1 --- /dev/null +++ b/core/schemas/realtime_client_secrets.go @@ -0,0 +1,66 @@ +package schemas + +import ( + "bytes" + "encoding/json" + "strings" +) + +// ParseRealtimeClientSecretBody parses a realtime client-secret request body +// into a mutable raw JSON map while preserving unknown fields. +func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) { + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err) + } + return root, nil +} + +// ExtractRealtimeClientSecretModel extracts the model from either session.model +// or the legacy top-level model field. +func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) { + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) { + var session map[string]json.RawMessage + if err := Unmarshal(sessionJSON, &session); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err) + } + if modelJSON, ok := session["model"]; ok { + var sessionModel string + if err := Unmarshal(modelJSON, &sessionModel); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err) + } + if strings.TrimSpace(sessionModel) != "" { + return strings.TrimSpace(sessionModel), nil + } + } + } + + if modelJSON, ok := root["model"]; ok { + var model string + if err := Unmarshal(modelJSON, &model); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err) + } + if strings.TrimSpace(model) != "" { + return strings.TrimSpace(model), nil + } + } + + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil) +} + +// NewRealtimeClientSecretBodyError builds a standard invalid-request style error +// for HTTP realtime client-secret request parsing/validation. +func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError { + return &BifrostError{ + IsBifrostError: false, + StatusCode: Ptr(status), + Error: &ErrorField{ + Type: Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: BifrostErrorExtraFields{ + RequestType: RealtimeRequest, + }, + } +} diff --git a/core/schemas/realtime_client_secrets_test.go b/core/schemas/realtime_client_secrets_test.go new file mode 100644 index 0000000000..dfd8f8b1d3 --- /dev/null +++ b/core/schemas/realtime_client_secrets_test.go @@ -0,0 +1,40 @@ +package schemas + +import ( + "encoding/json" + "testing" +) + +func TestExtractRealtimeClientSecretModel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "openai/gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview") + } +} + +func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } +} diff --git a/core/schemas/realtime_test.go b/core/schemas/realtime_test.go new file mode 100644 index 0000000000..69e9e403c8 --- /dev/null +++ b/core/schemas/realtime_test.go @@ -0,0 +1,68 @@ +package schemas + +import "testing" + +func TestIsRealtimeConversationItemEventType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + eventType RealtimeEventType + want bool + }{ + {name: "create", eventType: RTEventConversationItemCreate, want: true}, + {name: "added", eventType: RTEventConversationItemAdded, want: true}, + {name: "created", eventType: RTEventConversationItemCreated, want: true}, + {name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true}, + {name: "done", eventType: RTEventConversationItemDone, want: true}, + {name: "response done", eventType: RTEventResponseDone, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want { + t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want) + } + }) + } +} + +func TestRealtimeCanonicalEventClassifiers(t *testing.T) { + t.Parallel() + + userEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemAdded, + Item: &RealtimeItem{ + Role: "user", + Type: "message", + }, + } + if !IsRealtimeUserInputEvent(userEvent) { + t.Fatal("expected conversation.item.added user event to be classified as realtime user input") + } + if IsRealtimeToolOutputEvent(userEvent) { + t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output") + } + + toolEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemRetrieved, + Item: &RealtimeItem{ + Type: "function_call_output", + }, + } + if !IsRealtimeToolOutputEvent(toolEvent) { + t.Fatal("expected function_call_output item to be classified as realtime tool output") + } + if IsRealtimeUserInputEvent(toolEvent) { + t.Fatal("did not expect function_call_output item to be classified as realtime user input") + } + + transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted} + if !IsRealtimeInputTranscriptEvent(transcriptEvent) { + t.Fatal("expected input audio transcription completion to be classified as transcript event") + } + if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) { + t.Fatal("did not expect input audio transcription delta to be classified as transcript event") + } +} diff --git a/core/schemas/trace.go b/core/schemas/trace.go index 9a69980d3c..d6862d4d4e 100644 --- a/core/schemas/trace.go +++ b/core/schemas/trace.go @@ -8,6 +8,7 @@ import ( // Trace represents a distributed trace that captures the full lifecycle of a request type Trace struct { + RequestID string // Request ID for the trace TraceID string // Unique identifier for this trace ParentID string // Parent trace ID from incoming W3C traceparent header RootSpan *Span // The root span of this trace @@ -15,6 +16,7 @@ type Trace struct { StartTime time.Time // When the trace started EndTime time.Time // When the trace completed Attributes map[string]any // Additional attributes for the trace + PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing mu sync.Mutex // Mutex for thread-safe span operations } @@ -37,15 +39,49 @@ func (t *Trace) GetSpan(spanID string) *Span { return nil } +// GetRequestID retrieves the request ID from the trace +func (t *Trace) GetRequestID() string { + t.mu.Lock() + defer t.mu.Unlock() + return t.RequestID +} + +// SetRequestID sets the request ID for the trace +func (t *Trace) SetRequestID(requestID string) { + t.mu.Lock() + defer t.mu.Unlock() + t.RequestID = requestID +} + // Reset clears the trace for reuse from pool func (t *Trace) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + t.RequestID = "" t.TraceID = "" t.ParentID = "" t.RootSpan = nil + for i := range t.Spans { + t.Spans[i] = nil + } t.Spans = t.Spans[:0] t.StartTime = time.Time{} t.EndTime = time.Time{} t.Attributes = nil + for i := range t.PluginLogs { + t.PluginLogs[i] = PluginLogEntry{} + } + t.PluginLogs = t.PluginLogs[:0] +} + +// AppendPluginLogs appends plugin log entries to the trace in a thread-safe manner. +func (t *Trace) AppendPluginLogs(logs []PluginLogEntry) { + if len(logs) == 0 { + return + } + t.mu.Lock() + t.PluginLogs = append(t.PluginLogs, logs...) + t.mu.Unlock() } // Span represents a single operation within a trace diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go index 06f6487f8c..23c5d4cc4c 100644 --- a/core/schemas/tracer.go +++ b/core/schemas/tracer.go @@ -14,7 +14,8 @@ type SpanHandle interface{} // This is the return type for tracer's streaming accumulation methods. type StreamAccumulatorResult struct { RequestID string // Request ID - Model string // Model used + RequestedModel string // Original model requested by the caller + ResolvedModel string // Actual model used by the provider (equals RequestedModel when no alias mapping exists) Provider ModelProvider // Provider used Status string // Status of the stream Latency int64 // Latency in milliseconds @@ -38,7 +39,8 @@ type StreamAccumulatorResult struct { type Tracer interface { // CreateTrace creates a new trace with optional parent ID and returns the trace ID. // The parentID can be extracted from W3C traceparent headers for distributed tracing. - CreateTrace(parentID string) string + // The requestID is optional and can be used to identify the request. + CreateTrace(parentID string, requestID ...string) string // EndTrace completes a trace and returns the trace data for observation/export. // After this call, the trace is removed from active tracking and returned for cleanup. @@ -68,7 +70,7 @@ type Tracer interface { // PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. // This includes output messages, tokens, usage stats, and error information if present. - PopulateLLMResponseAttributes(handle SpanHandle, resp *BifrostResponse, err *BifrostError) + PopulateLLMResponseAttributes(ctx *BifrostContext, handle SpanHandle, resp *BifrostResponse, err *BifrostError) // StoreDeferredSpan stores a span handle for later completion (used for streaming requests). // The span handle is stored keyed by trace ID so it can be retrieved when the stream completes. @@ -111,6 +113,14 @@ type Tracer interface { // The ctx parameter must contain the stream end indicator for proper final chunk detection. ProcessStreamingChunk(traceID string, isFinalChunk bool, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult + // AttachPluginLogs appends plugin log entries to the trace identified by traceID. + // Thread-safe. Should be called after plugin hooks complete, before trace completion. + AttachPluginLogs(traceID string, logs []PluginLogEntry) + + // CompleteAndFlushTrace ends a trace, exports it to observability plugins, and + // releases the trace resources. Used by transports that bypass normal HTTP trace completion. + CompleteAndFlushTrace(traceID string) + // Stop releases resources associated with the tracer. // Should be called during shutdown to stop background goroutines. Stop() @@ -121,7 +131,7 @@ type Tracer interface { type NoOpTracer struct{} // CreateTrace returns an empty string (no trace created). -func (n *NoOpTracer) CreateTrace(_ string) string { return "" } +func (n *NoOpTracer) CreateTrace(_ string, _ ...string) string { return "" } // EndTrace returns nil (no trace to end). func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil } @@ -144,7 +154,7 @@ func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {} func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {} // PopulateLLMResponseAttributes does nothing. -func (n *NoOpTracer) PopulateLLMResponseAttributes(_ SpanHandle, _ *BifrostResponse, _ *BifrostError) { +func (n *NoOpTracer) PopulateLLMResponseAttributes(_ *BifrostContext, _ SpanHandle, _ *BifrostResponse, _ *BifrostError) { } // StoreDeferredSpan does nothing. @@ -176,6 +186,12 @@ func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, return nil } +// AttachPluginLogs does nothing. +func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {} + +// CompleteAndFlushTrace does nothing. +func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {} + // Stop does nothing. func (n *NoOpTracer) Stop() {} diff --git a/core/schemas/transcriptions.go b/core/schemas/transcriptions.go index 7308714ed5..1cd801be98 100644 --- a/core/schemas/transcriptions.go +++ b/core/schemas/transcriptions.go @@ -14,15 +14,37 @@ func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte { } type BifrostTranscriptionResponse struct { - Duration *float64 `json:"duration,omitempty"` // Duration in seconds - Language *string `json:"language,omitempty"` // e.g., "english" - LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"` - Segments []TranscriptionSegment `json:"segments,omitempty"` - Task *string `json:"task,omitempty"` // e.g., "transcribe" - Text string `json:"text"` - Usage *TranscriptionUsage `json:"usage,omitempty"` - Words []TranscriptionWord `json:"words,omitempty"` - ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + Duration *float64 `json:"duration,omitempty"` // Duration in seconds + Language *string `json:"language,omitempty"` // e.g., "english" + LogProbs []TranscriptionLogProb `json:"logprobs,omitempty"` + Segments []TranscriptionSegment `json:"segments,omitempty"` + Task *string `json:"task,omitempty"` // e.g., "transcribe" + Text string `json:"text"` + Usage *TranscriptionUsage `json:"usage,omitempty"` + Words []TranscriptionWord `json:"words,omitempty"` + ResponseFormat *string `json:"-"` // Set by provider for non-JSON formats (text, srt, vtt); used by integration response converters + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` +} + +func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) { + if r == nil || req == nil || req.Params == nil || req.Params.ResponseFormat == nil { + return + } + r.ResponseFormat = req.Params.ResponseFormat +} + +// IsPlainTextTranscriptionFormat returns true if the given response format +// produces a plain-text response body (not JSON). +func IsPlainTextTranscriptionFormat(format *string) bool { + if format == nil { + return false + } + switch *format { + case "text", "srt", "vtt": + return true + default: + return false + } } type TranscriptionInput struct { @@ -31,17 +53,17 @@ type TranscriptionInput struct { } type TranscriptionParameters struct { - Language *string `json:"language,omitempty"` - Prompt *string `json:"prompt,omitempty"` - ResponseFormat *string `json:"response_format,omitempty"` // Default is "json" - Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0) - TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json - Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs) - Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini - MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace - MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace - MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace - MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace + Language *string `json:"language,omitempty"` + Prompt *string `json:"prompt,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` // Default is "json" + Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0) + TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json + Include []string `json:"include,omitempty"` // Additional response info (e.g., logprobs) + Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini + MaxLength *int `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace + MinLength *int `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace + MaxNewTokens *int `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace + MinNewTokens *int `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace // Elevenlabs-specific fields AdditionalFormats []TranscriptionAdditionalFormat `json:"additional_formats,omitempty"` @@ -132,4 +154,3 @@ type BifrostTranscriptionStreamResponse struct { Usage *TranscriptionUsage `json:"usage,omitempty"` ExtraFields BifrostResponseExtraFields `json:"extra_fields"` } - diff --git a/core/utils.go b/core/utils.go index ed8f40ebf4..12d86e2508 100644 --- a/core/utils.go +++ b/core/utils.go @@ -11,6 +11,7 @@ import ( "math/rand" "net" "net/url" + "slices" "strings" "time" @@ -86,19 +87,19 @@ func Ptr[T any](v T) *T { } // providerRequiresKey returns true if the given provider requires an API key for authentication. -// Some providers like Ollama, SGL, and vLLM are keyless and don't require API keys. -func providerRequiresKey(providerKey schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) bool { +func providerRequiresKey(customConfig *schemas.CustomProviderConfig) bool { // Keyless custom providers are not allowed for Bedrock. if customConfig != nil && customConfig.IsKeyLess && customConfig.BaseProviderType != schemas.Bedrock { return false } - return !IsKeylessProvider(providerKey) + return true } -// canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. -// Some providers like Vertex and Bedrock have their credentials in additional key configs.. +// CanProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty. +// Some providers like Vertex and Bedrock have their credentials in additional key configs. +// Ollama and SGL are keyless (API Key is optional) but use per-key server URLs. func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure + return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Ollama || providerKey == schemas.SGL } func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool { @@ -131,6 +132,51 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError { return nil } +// validateKey validates the given key. +func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) bool { + // Valid the key for the provider + switch providerKey { + case schemas.Azure: + if key.AzureKeyConfig == nil { + return false + } + if key.AzureKeyConfig.Endpoint.GetValue() == "" { + return false + } + case schemas.Bedrock: + // Key is valid if either: + // 1. BedrockKeyConfig is provided + // 2. Value is provided and is not empty + if key.BedrockKeyConfig == nil { + if key.Value.GetValue() == "" { + return false + } + key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} + } + case schemas.Vertex: + if key.VertexKeyConfig == nil { + return false + } + case schemas.Replicate: + if key.ReplicateKeyConfig == nil { + return false + } + case schemas.VLLM: + if key.VLLMKeyConfig == nil || key.VLLMKeyConfig.URL.GetValue() == "" { + return false + } + case schemas.Ollama: + if key.OllamaKeyConfig == nil || key.OllamaKeyConfig.URL.GetValue() == "" { + return false + } + case schemas.SGL: + if key.SGLKeyConfig == nil || key.SGLKeyConfig.URL.GetValue() == "" { + return false + } + } + return true +} + // IsRateLimitErrorMessage checks if an error message indicates a rate limit issue func IsRateLimitErrorMessage(errorMessage string) bool { if errorMessage == "" { @@ -175,7 +221,7 @@ func newBifrostErrorFromMsg(message string) *schemas.BifrostError { // newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context. // It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled). -func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, requestType schemas.RequestType, stage string) *schemas.BifrostError { +func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError { var statusCode int var errorType string var message string @@ -199,11 +245,6 @@ func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelP Message: message, Error: ctx.Err(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: provider, - ModelRequested: model, - }, } } @@ -230,6 +271,7 @@ func newBifrostMessageChan(message *schemas.BifrostResponse) chan *schemas.Bifro func clearCtxForFallback(ctx *schemas.BifrostContext) { ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) + ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -261,11 +303,6 @@ func IsStandardProvider(providerKey schemas.ModelProvider) bool { return ok } -// IsKeylessProvider reports whether providerKey is a keyless provider. -func IsKeylessProvider(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Ollama || providerKey == schemas.SGL -} - // IsStreamRequestType returns true if the given request type is a stream request. func IsStreamRequestType(reqType schemas.RequestType) bool { return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.ImageEditStreamRequest || reqType == schemas.PassthroughStreamRequest || reqType == schemas.WebSocketResponsesRequest || reqType == schemas.RealtimeRequest @@ -336,14 +373,14 @@ func IsFinalChunk(ctx *schemas.BifrostContext) bool { return false } -// GetResponseFields extracts the request type, provider, and model from the result or error -func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, model string) { +// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error. +func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) { if result != nil { extraFields := result.GetExtraFields() - return extraFields.RequestType, extraFields.Provider, extraFields.ModelRequested + return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed } if err != nil { - return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.ModelRequested + return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed } return } @@ -544,3 +581,17 @@ func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model } return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) } + +// isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt. +// It normalises hyphenated variants (e.g. "erase-object") to underscore form before matching. +func isPromptOptionalImageEditType(t *string) bool { + if t == nil { + return false + } + normalized := strings.ToLower(strings.TrimSpace(*t)) + normalized = strings.ReplaceAll(normalized, "-", "_") + return slices.Contains( + []string{"background_removal", "remove_background", "remove_bg", "erase_object", "upscale_fast"}, + normalized, + ) +} diff --git a/core/version b/core/version index ae1d35b779..8e03717dca 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.4.17 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/docs/changelogs/v1.5.0-prerelease1.mdx b/docs/changelogs/v1.5.0-prerelease1.mdx index 2de8b30e52..8727232fb7 100644 --- a/docs/changelogs/v1.5.0-prerelease1.mdx +++ b/docs/changelogs/v1.5.0-prerelease1.mdx @@ -39,6 +39,8 @@ description: "v1.5.0-prerelease1 changelog - 2026-04-01" ## Breaking Changes in This Release +This prerelease introduces 9 breaking changes. See the **[v1.5.0 Migration Guide](/migration-guides/v1.5.0)** for full before/after examples, automatic migration details, and a step-by-step checklist. + | # | Breaking Change | Affected | |---|---|---| | [1](/migration-guides/v1.5.0#breaking-change-1-empty-array-now-means-deny-all) | Empty array (`[]`) now means "deny all" on all allow-list fields | `config.json`, REST API | diff --git a/docs/docs.json b/docs/docs.json index 6df5da9cf5..4d13f57643 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -149,6 +149,7 @@ }, "providers/routing-rules", "providers/provider-routing", + "providers/aliasing-models", "providers/reasoning", "providers/performance", "providers/custom-providers", @@ -215,7 +216,8 @@ "group": "Prompt Repository", "icon": "folder", "pages": [ - "features/prompt-repository/playground" + "features/prompt-repository/playground", + "features/prompt-repository/prompts-plugin" ] }, { @@ -486,7 +488,6 @@ "icon": "rocket", "pages": [ "changelogs/v1.5.0-prerelease1", - "changelogs/v1.4.20", { "group": "March 2026", "pages": [ diff --git a/docs/enterprise/setting-up-okta.mdx b/docs/enterprise/setting-up-okta.mdx index 7b6f25f0bc..6435459f7c 100644 --- a/docs/enterprise/setting-up-okta.mdx +++ b/docs/enterprise/setting-up-okta.mdx @@ -13,7 +13,7 @@ This guide walks you through configuring Okta as your identity provider for Bifr - An Okta organization with admin access - Bifrost Enterprise deployed and accessible - The redirect URI for your Bifrost instance (e.g., `https://your-bifrost-domain.com/login`) - +- Ensure you have created all the [roles in Bifrost](/enterprise/rbac) that you are aiming to map to with Okta. --- ## Step 1: Create an OIDC Application @@ -71,39 +71,12 @@ Configure the following settings for your application: --- -## Step 3: Configure Authorization Server (optional) +## Step 3: Create Custom Role Attribute -The default authorization server (`/oauth2/default`) is available to all Okta plans and **supports custom claims**, including role claims. The API Access Management paid add-on is only required to create additional custom authorization servers beyond the default. +You can use both roles and/or groups for assigning roles to users. You can learn more about [RBAC](/enterprise/rbac) docs. Roles take precedence over groups in role assignment. -Bifrost uses Okta's Authorization Server to issue tokens. You have three options: - -1. **Use `/oauth2/default` with role claims (recommended)** β€” Complete Steps 4-7 to configure custom role claims on the default authorization server. This enables automatic RBAC synchronization. - -2. **Use `/oauth2/default` without role claims** β€” Skip Steps 4-7. The first user to sign in automatically receives the Admin role and can manage RBAC for all subsequent users through the Bifrost dashboard. - -3. **Skip Step 3 entirely** β€” Authorization is not configured through Okta. You'll need an alternative authentication mechanism. - -### Configuring the Authorization Server - -1. Navigate to **Security** β†’ **API** -2. Click on **Authorization Servers** - - - Okta Authorization Servers - - -3. Note the **Issuer URI** for your authorization server (e.g., `https://your-domain.okta.com/oauth2/default`) - - -The Issuer URI is used as the `issuerUrl` in your Bifrost configuration. Make sure to use the full URL including `/oauth2/default` (or your custom authorization server path). - - ---- - -## Step 4: Create Custom Role Attribute - To map Okta users to Bifrost roles (Admin, Developer, Viewer), you need to create a custom attribute. 1. Navigate to **Directory** β†’ **Profile Editor** @@ -133,7 +106,7 @@ To map Okta users to Bifrost roles (Admin, Developer, Viewer), you need to creat --- -## Step 5: Add Role Claim to Tokens +## Step 4: Add Role Claim to Tokens Configure the authorization server to include the role in the access token. @@ -164,11 +137,11 @@ If you named your custom attribute differently, update the Value expression acco --- -## Step 6: Configure Groups for Team and Role Synchronization +## Step 5: Configure Groups Bifrost can automatically sync Okta groups for two purposes: - **Team synchronization** β€” Groups are synced as Bifrost teams -- **Role mapping** β€” Groups can be mapped to Bifrost roles (Admin, Developer, Viewer) using Group-to-Role Mappings in the Bifrost UI +- **Role mapping** β€” Groups can be mapped to Bifrost roles (Admin, Developer, Viewer) using Group-to-Role Mappings in the Bifrost UI. ### Create Groups in Okta @@ -191,31 +164,6 @@ Use a consistent naming convention for your groups. This makes it easier to conf ### Add Groups Claim to Tokens -You have two options for configuring the groups claim. Choose the one that best fits your Okta plan and requirements. - -#### Option A: Using App-Level Groups Claim (All Okta Plans) - -This approach configures the groups claim directly in your application's settings and works with all Okta plans, including free tiers. - -1. Navigate to your application's **Sign On** tab -2. Scroll down to the **OpenID Connect ID Token** section -3. Click **Edit** to modify the settings -4. Configure the **Groups claim filter**: - - **Groups claim type**: Filter - - **Groups claim filter**: Set a claim name (e.g., `groups`) and filter condition (e.g., "Starts with" `bifrost-staging`) - - - Application Groups claim configuration - - -5. Click **Save** - - -The filter ensures only relevant groups are included in the token. Adjust the filter condition based on your group naming convention. - - -#### Option B: Using Authorization Server Groups Claim - This approach adds the groups claim through your authorization server, providing more flexibility for complex configurations. 1. Navigate to **Security** β†’ **API** β†’ **Authorization Servers** @@ -235,25 +183,9 @@ Configure the groups claim: 5. Click **Create** -You can also configure an additional groups claim in the application's Sign On settings: - -1. Navigate to your application's **Sign On** tab - - - Application Sign On configuration - - -2. Under **OpenID Connect ID Token**, configure: - - **Groups claim type**: Expression - - **Groups claim expression**: `Arrays.flatten(Groups.startsWith("OKTA", "bifrost", 100))` - - -Adjust the group filter expression based on your naming convention. The example above includes groups starting with "bifrost". - - --- -## Step 7: Assign Users to the Application +## Step 6: Assign Users to the Application 1. Navigate to your application's **Assignments** tab @@ -263,7 +195,9 @@ Adjust the group filter expression based on your naming convention. The example 2. Click **Assign** β†’ **Assign to People** or **Assign to Groups** -3. For each user, set their **bifrostRole**: +### For Assigning Roles + +For each user, set their **bifrostRole** (if you are planning to do role-level mapping): Assign custom role to user @@ -277,6 +211,22 @@ Role claims are available only when you configure custom claims on your authoriz --- +## Step 7: Create API token for bulk user and team sync + +To create an API token, navigate to **Security** β†’ **API** β†’ **Tokens**. + + +Okta API tokens screen + + +1. Click on "Create token" + + + Create token dialog in Okta + + +2. Copy token to be used in the next step. + ## Step 8: Configure Bifrost Now configure Bifrost to use Okta as the identity provider. @@ -297,9 +247,9 @@ Now configure Bifrost to use Okta as the identity provider. 4. Toggle **Enabled** to activate the provider 5. Click **Save Configuration** -### Group-to-Role Mappings (Optional) +### Group-to-Role Mappings -If you configured groups in Okta (Step 6), you can map Okta group names directly to Bifrost roles. This is an alternative to using custom role claims (Steps 4-5) and works with all Okta plans. +If you configured groups in Okta (Step 5), you can map Okta group names directly to Bifrost roles. This is an alternative to using custom role claims (Steps 3-4) and works with all Okta plans. 1. In the User Provisioning configuration, scroll down to **Group-to-Role Mappings** 2. Click **Add Mapping** diff --git a/docs/features/litellm-compat.mdx b/docs/features/litellm-compat.mdx index b26f94cd7e..51cd26dcd9 100644 --- a/docs/features/litellm-compat.mdx +++ b/docs/features/litellm-compat.mdx @@ -125,7 +125,8 @@ When either transformation is applied: - `extra_fields.litellm_compat`: Set to `true` - `extra_fields.provider`: The provider that handled the request - `extra_fields.request_type`: Reflects the original request type -- `extra_fields.model_requested`: The originally requested model +- `extra_fields.original_model_requested`: The originally requested model +- `extra_fields.resolved_model_used`: The actual provider API identifier used (equals original_model_requested when no alias mapping exists) ### Error Handling diff --git a/docs/features/prompt-repository/playground.mdx b/docs/features/prompt-repository/playground.mdx index caffea4080..30c2b0df9a 100644 --- a/docs/features/prompt-repository/playground.mdx +++ b/docs/features/prompt-repository/playground.mdx @@ -229,3 +229,9 @@ With sessions you can: - Switch between past experiments ![Sessions](../../media/prompt-repo-sessions.png) + +--- + +## Using prompts in production + +To attach committed versions to **Chat Completions** or **Responses** requests through the gateway (HTTP headers, merging, and caching behavior), see the [Prompts plugin](/features/prompt-repository/prompts-plugin). diff --git a/docs/features/prompt-repository/prompts-plugin.mdx b/docs/features/prompt-repository/prompts-plugin.mdx new file mode 100644 index 0000000000..50fd6def32 --- /dev/null +++ b/docs/features/prompt-repository/prompts-plugin.mdx @@ -0,0 +1,134 @@ +--- +title: "Prompts plugin" +description: "Use committed prompt templates from the Prompt Repository on inference requests via HTTP headers or custom resolvers." +icon: "puzzle-piece" +--- + +## Overview + +The **Prompts** plugin connects the [Prompt Repository](/features/prompt-repository/playground) to inference. It loads committed prompt versions from the config store and **prepends** their messages to **Chat Completions** and **Responses** requests. It also **merges model parameters** from the stored version with the incoming request (request values take precedence). + +**What it does:** + +- Resolves which prompt and version to apply per request (default: HTTP headers). +- Injects the version’s message history **before** the client’s messages. +- Applies the version’s `model` parameters as defaults, then overrides with whatever the client sent for the same parameters. + +--- + +## Prerequisites + +- **Config store** with Prompt Repository tables (typically **PostgreSQL**). File-backed config alone does not store prompts. +- Prompts authored and **committed as versions** in the UI or via the `/api/prompt-repo/...` HTTP API (see `docs/openapi/openapi.yaml` in the repository). +- A **prompt ID** (UUID) for each prompt you reference at runtime. You can read it from the repository API or the playground. + +--- + +## How it works + +```mermaid +flowchart TB + Client([Client]) --> Gateway[Bifrost HTTP] + Gateway --> PreHook["HTTP transport pre-hook:
copy bf-prompt-id / bf-prompt-version to context"] + PreHook --> PreLLM["PreLLM hook:
resolve version, merge params,
prepend template messages"] + PreLLM --> Provider[Provider] +``` + +1. **Transport (HTTP):** Incoming headers `bf-prompt-id` and `bf-prompt-version` are copied onto the Bifrost context (header name matching is case-insensitive). +2. **Resolve:** The plugin looks up the prompt and the requested version. If **`bf-prompt-version` is omitted**, the prompt’s **latest committed version** is used. +3. **Parameters:** Version `model` parameters are merged into the request; any field already set on the request wins. +4. **Messages:** Messages from the committed version are **prepended** to `messages` (chat) or `input` (responses). Your request body adds the user turn(s) after the template. + +If the prompt ID is missing, the plugin does nothing and the request passes through unchanged. + +--- + +## HTTP headers (gateway) + +| Header | Required | Description | +|--------|----------|-------------| +| `bf-prompt-id` | Yes, to enable injection | UUID of the prompt in the repository. | +| `bf-prompt-version` | No | **Integer version number** (e.g. `3` for v3). If omitted, the **latest** committed version for that prompt is used. | + +Invalid or unknown IDs / versions are logged as warnings; the request is **not** failed by the plugin (it proceeds without template injection). + +--- + +## Example: Chat Completions + +Use the same JSON body as a normal chat request. Only the headers select the template. + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "x-bf-vk: sk-bf-your-virtual-key" \ + -d '{ + "model": "openai/gpt-5.4", + "messages": [ + { + "role": "user", + "content": "Tell me about Bifrost Gateway?" + } + ] + }' +``` + +![Commit Version with Stream enabled in the playground](../../media/prompt-plugin-version-commit.png) + +When you commit a version from the playground, **Stream** is saved in that version’s model parameters. The example `curl` above does not set `"stream": true` in the JSON body, but if the committed version was saved with streaming enabled (as in the screenshot), the merged parameters still include `stream: true`, so the request is handled as **streaming** even though the client did not send `stream` explicitly. + +![LLM log for the same request showing Type: Chat Stream](../../media/prompt-plugin-llm-log.png) + +In **Logs**, that run shows **Type: Chat Stream** and the full conversation: the committed **system** template, your **user** message from the request body, and the assistant reply. + +The provider receives the **stored** messages from the prompt version, checks if the request is streaming or non-streaming, applies the additional model parameters from the request and prepends the messages from the prompt version followed by your user message. + +--- + +## Example: Responses API + +```bash +curl -X POST http://localhost:8080/v1/responses \ + -H "Content-Type: application/json" \ + -H "bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "bf-prompt-version: 4" \ + -H "x-bf-vk: sk-bf-your-virtual-key" \ + -d '{ + "model": "openai/gpt-5-nano-2025-08-07", + "input": "What is Pale Blue Dot?" + }' +``` + +--- + +## Streaming + +If the committed version’s **model parameters** include `"stream": true`, the plugin may set streaming on the HTTP transport so behavior matches the saved version. Client-side `stream` flags still interact with the merged parameters as usual. + +--- + +## Cache and updates + +The plugin keeps an in-memory cache of prompts and versions (loaded with a small number of store queries at startup). When you create, update, or delete prompts or versions through the **gateway APIs**, the server **reloads** that cache so new commits are visible without a full process restart. + +--- + +## Go SDK and custom resolution + +For embedded Bifrost (Go SDK), register the plugin with `prompts.Init` and a **config store** that implements the prompt tables API. The default resolver reads the same logical keys from `BifrostContext`: + +- `prompts.PromptIDKey` (`bf-prompt-id`) +- `prompts.PromptVersionKey` (`bf-prompt-version`) + +Set them on the context you pass to `ChatCompletion` / `Responses` if you are not going through the HTTP transport hooks. + +For advanced routing (for example, choosing a prompt from governance metadata), implement `prompts.PromptResolver` in `plugins/prompts/main.go` and use **`prompts.InitWithResolver`**. + +--- + +## Related + +- [Playground](/features/prompt-repository/playground) β€” create folders, prompts, sessions, and committed versions. +- [Writing Go plugins](/plugins/writing-go-plugin) β€” plugin interfaces and lifecycle. +- Built-in plugin name in code: `prompts` (`github.com/maximhq/bifrost/plugins/prompts`). diff --git a/docs/mcp/connecting-to-servers.mdx b/docs/mcp/connecting-to-servers.mdx index d3fc4274cb..119b46f32f 100644 --- a/docs/mcp/connecting-to-servers.mdx +++ b/docs/mcp/connecting-to-servers.mdx @@ -450,6 +450,191 @@ Environment variables are: --- +## Forwarding Request Headers to MCP Servers + + +Header Forwarding is available in **v1.5.0-prerelease1 and above**. + + +By default, Bifrost does not forward incoming request headers to MCP servers during tool execution. The `allowed_extra_headers` field lets you define a per-client allowlist of headers that callers may inject at request time and have forwarded to that MCP server when tools are executed. + +This is separate from the static `headers` field used for authentication: + +| Field | Purpose | When sent | +|-------|---------|-----------| +| `headers` | Static auth credentials (API keys, tokens) | Always, on every tool call | +| `allowed_extra_headers` | Dynamic per-request headers from callers | Only when the caller provides them, and only if they match the allowlist | + +**Common use cases:** +- Forwarding a user's auth token to an MCP server that enforces per-user authorization +- Passing a tenant or org ID to a multi-tenant MCP server +- Propagating trace or correlation IDs for end-to-end observability + +### How It Works + +1. An incoming request carries one or more headers matching a client's `allowed_extra_headers` pattern +2. Bifrost captures those headers from the request (using the union of all clients' allowlists) +3. At tool execution time, each client **re-checks** the header against its own allowlist β€” so the same header can be forwarded to one MCP server but not another + + +Headers are matched case-insensitively. The only wildcard supported is a standalone `"*"` (allow all headers) β€” partial patterns like `x-tenant-*` are not supported. If `"*"` is used, it must be the only entry in the list. + + + + + +**Configure:** Navigate to **MCP Gateway**, open the configuration sheet for an HTTP or SSE client, and set the **Allowed Extra Headers** field: + + + Allowed Extra Headers configuration in the MCP client edit sheet + + +**Send headers:** Include the allowed headers in any inference request to the LLM gateway: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-user-token: eyJhbGci..." \ + -H "x-tenant-id: acme-corp" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "Look up my account details"}] + }' +``` + + + + +**Configure:** Include `allowed_extra_headers` when creating or updating a client: + +```bash +curl -X POST http://localhost:8080/api/mcp/client \ + -H "Content-Type: application/json" \ + -d '{ + "name": "my_api", + "connection_type": "http", + "connection_string": "https://mcp.example.com/mcp", + "auth_type": "headers", + "headers": { + "Authorization": "Bearer service-token" + }, + "allowed_extra_headers": ["x-user-token", "x-tenant-id", "x-request-id"], + "tools_to_execute": ["*"] + }' +``` + +**Send headers:** Include the allowed headers in any inference request: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-user-token: eyJhbGci..." \ + -H "x-tenant-id: acme-corp" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "Look up my account details"}] + }' +``` + + + + +**Configure:** + +```json +{ + "mcp": { + "client_configs": [ + { + "name": "my_api", + "connection_type": "http", + "connection_string": "https://mcp.example.com/mcp", + "auth_type": "headers", + "headers": { + "Authorization": "Bearer service-token" + }, + "allowed_extra_headers": ["x-user-token", "x-tenant-id", "x-request-id"], + "tools_to_execute": ["*"] + } + ] + } +} +``` + +**Send headers:** Include the allowed headers in any inference request: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-user-token: eyJhbGci..." \ + -H "x-tenant-id: acme-corp" \ + -d '{ + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "Look up my account details"}] + }' +``` + + + + +**Configure** the client as above (Web UI, Management API, or config.json). + +**Send headers:** When an external MCP client (e.g., Claude Desktop, Cursor) connects to Bifrost's `/mcp` endpoint, include the allowed headers in that HTTP request. Bifrost forwards them during any tool call made within that session: + +```json +{ + "mcpServers": { + "bifrost": { + "url": "http://localhost:8080/mcp", + "headers": { + "x-user-token": "eyJhbGci...", + "x-tenant-id": "acme-corp" + } + } + } +} +``` + + +Header support in MCP client config varies by client. The above JSON format applies to clients that support custom headers (e.g., Claude Desktop, Cursor). Check your MCP client's documentation for the exact configuration syntax. + + + + + +**Configure:** + +```go +schemas.MCPClientConfig{ + Name: "my_api", + ConnectionType: schemas.MCPConnectionTypeHTTP, + ConnectionString: bifrost.Ptr("https://mcp.example.com/mcp"), + AuthType: schemas.MCPAuthTypeHeaders, + Headers: map[string]schemas.EnvVar{ + "Authorization": {Value: "Bearer service-token"}, + }, + AllowedExtraHeaders: schemas.WhiteList{"x-user-token", "x-tenant-id", "x-request-id"}, + ToolsToExecute: []string{"*"}, +} +``` + +**Send headers:** Set `BifrostContextKeyMCPExtraHeaders` on the context before calling `ChatCompletionRequest` or `ExecuteChatMCPTool`: + +```go +bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) +bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, map[string][]string{ + "x-user-token": {"eyJhbGci..."}, + "x-tenant-id": {"acme-corp"}, +}) + +response, err := client.ChatCompletionRequest(bifrostCtx, request) +``` + + + + +--- + ## Client State Management ### Connection States diff --git a/docs/media/prompt-plugin-llm-log.png b/docs/media/prompt-plugin-llm-log.png new file mode 100644 index 0000000000..8ea06dd438 Binary files /dev/null and b/docs/media/prompt-plugin-llm-log.png differ diff --git a/docs/media/prompt-plugin-version-commit.png b/docs/media/prompt-plugin-version-commit.png new file mode 100644 index 0000000000..3741956acd Binary files /dev/null and b/docs/media/prompt-plugin-version-commit.png differ diff --git a/docs/media/ui-mcp-allowed-extra-headers.png b/docs/media/ui-mcp-allowed-extra-headers.png new file mode 100644 index 0000000000..a1c162ee3a Binary files /dev/null and b/docs/media/ui-mcp-allowed-extra-headers.png differ diff --git a/docs/media/ui-mcp-logs.png b/docs/media/ui-mcp-logs.png new file mode 100644 index 0000000000..55e586dc68 Binary files /dev/null and b/docs/media/ui-mcp-logs.png differ diff --git a/docs/media/ui-mcp-pricing.png b/docs/media/ui-mcp-pricing.png new file mode 100644 index 0000000000..85baf3dca9 Binary files /dev/null and b/docs/media/ui-mcp-pricing.png differ diff --git a/docs/media/ui-mcp-tool-group.png b/docs/media/ui-mcp-tool-group.png new file mode 100644 index 0000000000..6076eaa52d Binary files /dev/null and b/docs/media/ui-mcp-tool-group.png differ diff --git a/docs/media/ui-mcp-vk-config.png b/docs/media/ui-mcp-vk-config.png new file mode 100644 index 0000000000..299c6b835c Binary files /dev/null and b/docs/media/ui-mcp-vk-config.png differ diff --git a/docs/media/ui-routing-tree.png b/docs/media/ui-routing-tree.png new file mode 100644 index 0000000000..fa48b9833b Binary files /dev/null and b/docs/media/ui-routing-tree.png differ diff --git a/docs/media/user-provisioning/okta-api-token-created.png b/docs/media/user-provisioning/okta-api-token-created.png new file mode 100644 index 0000000000..e442519f8f Binary files /dev/null and b/docs/media/user-provisioning/okta-api-token-created.png differ diff --git a/docs/media/user-provisioning/okta-create-token-form.png b/docs/media/user-provisioning/okta-create-token-form.png new file mode 100644 index 0000000000..2888d28da7 Binary files /dev/null and b/docs/media/user-provisioning/okta-create-token-form.png differ diff --git a/docs/media/user-provisioning/okta-tokens-screen.png b/docs/media/user-provisioning/okta-tokens-screen.png new file mode 100644 index 0000000000..6530a8a6d7 Binary files /dev/null and b/docs/media/user-provisioning/okta-tokens-screen.png differ diff --git a/docs/media/user-provisioning/zitadel-add-role.png b/docs/media/user-provisioning/zitadel-add-role.png deleted file mode 100644 index f00212f6d2..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-role.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-add-user-select-key.png b/docs/media/user-provisioning/zitadel-add-user-select-key.png deleted file mode 100644 index ba8d8e52a9..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-user-select-key.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-auth-method.png b/docs/media/user-provisioning/zitadel-create-app-auth-method.png deleted file mode 100644 index c27a6e5772..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-auth-method.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-namne.png b/docs/media/user-provisioning/zitadel-create-app-namne.png deleted file mode 100644 index 7e220ce193..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-namne.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-uri.png b/docs/media/user-provisioning/zitadel-create-app-uri.png deleted file mode 100644 index 8796e77ec5..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-uri.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app.png b/docs/media/user-provisioning/zitadel-create-app.png new file mode 100644 index 0000000000..0400316932 Binary files /dev/null and b/docs/media/user-provisioning/zitadel-create-app.png differ diff --git a/docs/media/user-provisioning/zitadel-role-assignemnt.png b/docs/media/user-provisioning/zitadel-role-assignemnt.png deleted file mode 100644 index 5f233eb436..0000000000 Binary files a/docs/media/user-provisioning/zitadel-role-assignemnt.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-select-project.png b/docs/media/user-provisioning/zitadel-select-project.png deleted file mode 100644 index 824c48dd83..0000000000 Binary files a/docs/media/user-provisioning/zitadel-select-project.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-token-config.png b/docs/media/user-provisioning/zitadel-token-config.png deleted file mode 100644 index 1354a62830..0000000000 Binary files a/docs/media/user-provisioning/zitadel-token-config.png and /dev/null differ diff --git a/docs/migration-guides/v1.5.0.mdx b/docs/migration-guides/v1.5.0.mdx index 2fb81c673b..de384bfeca 100644 --- a/docs/migration-guides/v1.5.0.mdx +++ b/docs/migration-guides/v1.5.0.mdx @@ -3,7 +3,7 @@ title: "Migrating to v1.5.0" description: "Breaking changes and migration instructions for the v1.5.0 release" --- -v1.5.0 introduces several breaking changes across provider key configuration, Virtual Key semantics, the REST API, and plugins. This page consolidates every breaking change with before/after examples and a migration checklist. +v1.5.0 introduces several breaking changes across provider key configuration, Virtual Key semantics, the Go SDK, and the REST API. This page consolidates every breaking change with before/after examples and a migration checklist. **Make a database backup before upgrading.** Automatic database migrations run on startup and are not revertible. A backup is the only way to restore a previous state if anything goes wrong. A database successfully migrated to v1.5.0 cannot be used to run v1.4.x. @@ -21,6 +21,7 @@ The following automatic migrations run on upgrade: - Virtual Key provider configs with `allowed_models: []` are converted to `allowed_models: ["*"]` - Virtual Keys with no `provider_configs` are backfilled with all currently configured providers (`allowed_models: ["*"]`, `key_ids: ["*"]`) - Virtual Keys with no `mcp_configs` are backfilled with all currently connected MCP clients (`tools_to_execute: ["*"]`) +- Per-provider `deployments` maps (Azure, Bedrock, Vertex, Replicate) are migrated into the unified `aliases` field **The automatic migration only protects your existing data.** Any new configuration created after upgrading β€” via `config.json` or the REST API β€” must follow the new semantics described below. @@ -275,28 +276,29 @@ Support for image editing via `/v1/images/edits` on Replicate is also being remo --- -## Breaking Change 9: Ollama and SGL Per-Key URL Configuration +## Breaking Change 9: Provider `deployments` Removed β€” Migrate to `aliases` -**Who is affected:** Anyone running Ollama or SGL providers configured with `network_config.base_url` and no keys. +The per-provider `deployments` map has been removed from `azure_key_config`, `vertex_key_config`, `bedrock_key_config`, and `replicate_key_config`. A single top-level `aliases` field on each key replaces all of them. Aliases work across all providers and map any model name to a provider-specific identifier (deployment name, inference profile ARN, fine-tuned model ID, etc.). -### What changed - -Ollama and SGL previously used a single provider-level `base_url` and required no API keys. In v1.5.0, both providers use a **per-key URL model** β€” each key must include an `ollama_key_config` (or `sgl_key_config`) with the server `url`. This enables load balancing across multiple Ollama/SGL instances. - -- `network_config.base_url` is no longer used at runtime for Ollama/SGL β€” the URL must be on each key -- Ollama/SGL keys now **require** `ollama_key_config.url` / `sgl_key_config.url` -- The key management endpoints (`POST/PUT/DELETE /api/providers/{provider}/keys`) are now enabled for Ollama and SGL +The database migration runs automatically on startup, migrating existing deployment data into `aliases`. Only `config.json` files need to be updated manually. - -If you are running Bifrost with a database, existing Ollama/SGL providers are automatically migrated on startup. For each provider that has a `network_config.base_url`, the migration creates a default key with that URL. Only `config.json` files need to be updated manually. - +### Azure **Before:** ```json { "providers": { - "ollama": { - "network_config": { "base_url": "http://localhost:11434" } + "azure": { + "keys": [{ + "value": "env.AZURE_API_KEY", + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + } + } + }] } } } @@ -306,6 +308,17 @@ If you are running Bifrost with a database, existing Ollama/SGL providers are au ```json { "providers": { + "azure": { + "keys": [{ + "value": "env.AZURE_API_KEY", + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT" + }, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + } + }] "ollama": { "keys": [ { @@ -320,16 +333,141 @@ If you are running Bifrost with a database, existing Ollama/SGL providers are au } ``` -The same pattern applies to SGL β€” replace `ollama_key_config` with `sgl_key_config`. Server URLs support the `env.` prefix for environment variables. +### Bedrock + +**Before:** +```json +{ + "bedrock_key_config": { + "region": "env.AWS_REGION", + "deployments": { + "claude-3-5-sonnet": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0" + } + } +} +``` + +**After:** +```json +{ + "bedrock_key_config": { + "region": "env.AWS_REGION" + }, + "aliases": { + "claude-3-5-sonnet": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-5-sonnet-20241022-v2:0" + } +} +``` + +### Vertex + +**Before:** +```json +{ + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "env.VERTEX_REGION", + "auth_credentials": "env.VERTEX_AUTH_CREDENTIALS", + "deployments": { + "gemini-2.0-flash": "projects/my-project/locations/us-central1/endpoints/123456" + } + } +} +``` + +**After:** +```json +{ + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "env.VERTEX_REGION", + "auth_credentials": "env.VERTEX_AUTH_CREDENTIALS" + }, + "aliases": { + "gemini-2.0-flash": "projects/my-project/locations/us-central1/endpoints/123456" + } +} +``` + +### Replicate + +The Replicate key config is also restructured. The `deployments` map is gone. A new boolean `use_deployments_endpoint` controls whether requests are routed through the [Deployments API](https://replicate.com/docs/reference/http#deployments.predictions.create) (private, fixed hardware) or the standard Models API. + +**Before:** +```json +{ + "replicate_key_config": { + "deployments": { + "my-model": "owner/model-name/version-hash" + } + } +} +``` + +**After:** +```json +{ + "replicate_key_config": { + "use_deployments_endpoint": true + }, + "aliases": { + "my-model": "owner/model-name" + } +} +``` + +| Old field | New field | Notes | +|---|---|---| +| `replicate_key_config.deployments` | Removed | Use top-level `aliases` | +| _(new)_ | `replicate_key_config.use_deployments_endpoint` | `bool`, default `false` | --- -## Breaking Change 10: Go SDK Changes +## Breaking Change 10: Go SDK β€” `ExtraFields` Model Fields Renamed -If you import Bifrost's Go packages directly: +`ModelRequested string` has been replaced by two fields on `BifrostResponseExtraFields` and `BifrostErrorExtraFields`. -- **`HuggingFaceKeyConfig` removed** from the `Key` struct β€” remove any references to `HuggingFaceKeyConfig` or `huggingface_key_config` -- **`providerRequiresKey()` signature changed** β€” the `providerKey` parameter was removed; it now only accepts `*CustomProviderConfig` +**Before:** +```go +model := response.ExtraFields.ModelRequested +``` + +**After:** +```go +// The alias the caller passed as "model" in the request +original := response.ExtraFields.OriginalModelRequested + +// The actual identifier sent to the provider API +// Equals OriginalModelRequested when no alias is configured +resolved := response.ExtraFields.ResolvedModelUsed +``` + +The same rename applies to `BifrostErrorExtraFields`. + +**JSON tag changes:** + +| Old | New | +|---|---| +| `"model_requested"` | `"original_model_requested"` + `"resolved_model_used"` | + +--- + +## Breaking Change 11: Go SDK β€” `StreamAccumulatorResult` Field Renamed + +`Model string` has been replaced by two fields on `StreamAccumulatorResult` (returned by tracer streaming accumulation methods). + +**Before:** +```go +result.Model +``` + +**After:** +```go +result.RequestedModel // original alias from the caller +result.ResolvedModel // actual model identifier used by the provider +``` --- @@ -361,7 +499,7 @@ Update any client code that processes `weight` to accept `null` in addition to n -Ensure no list mixes `"*"` with specific values and no list has duplicate entries. +Ensure no list mixes `"*"` with specific values (e.g., `["*", "gpt-4o"]`) and no list has duplicate entries. @@ -372,12 +510,16 @@ Stop sending `keys` in provider create/update payloads and stop reading `keys` f Replace `enable_litellm_fallbacks` with the appropriate combination of `convert_text_to_chat`, `convert_chat_to_responses`, and `should_drop_params`. - -Move `network_config.base_url` into per-key `ollama_key_config.url` / `sgl_key_config.url` in `config.json`. Database users are auto-migrated on startup, but `config.json` must be updated manually. + +Remove `deployments` from `azure_key_config`, `vertex_key_config`, `bedrock_key_config`, and `replicate_key_config`. Move those mappings to the top-level `aliases` field on each key. For Replicate, set `use_deployments_endpoint: true` if you were using the deployments endpoint. + + + +Replace `ExtraFields.ModelRequested` with `ExtraFields.OriginalModelRequested` (and optionally read `ExtraFields.ResolvedModelUsed`). Update JSON consumers reading `"model_requested"` to use `"original_model_requested"` and `"resolved_model_used"`. - -Remove any references to `HuggingFaceKeyConfig` / `huggingface_key_config` from the `Key` struct. Update any direct calls to `providerRequiresKey()` β€” the `providerKey` parameter has been removed. + +Replace `.Model` with `.RequestedModel` (and optionally `.ResolvedModel`) on any `StreamAccumulatorResult` usage. @@ -387,7 +529,7 @@ Remove any references to `HuggingFaceKeyConfig` / `huggingface_key_config` from **All requests returning 403/blocked after upgrade** -A provider key has `models: []`, a Virtual Key has no `provider_configs`, or a provider config has `allowed_models: []`. Check Bifrost logs β€” a blocked request logs which rule denied it. Fix: add `"models": ["*"]` on provider keys and `"allowed_models": ["*"]` on VK provider configs. +A provider key has `models: []`, a Virtual Key has no `provider_configs`, or a provider config has `allowed_models: []`. Check Bifrost logs β€” a blocked request logs which rule denied it. Fix: add `"models": ["*"]` on provider keys, `"allowed_models": ["*"]` on VK provider configs. **MCP tools not being injected / tool calls blocked** @@ -399,12 +541,16 @@ A whitelist validation failure β€” either mixing `"*"` with specific values, or **"No keys available" or key selection errors** -A provider config with `key_ids` omitted or `[]` now blocks all keys. Add `"key_ids": ["*"]`. +A provider config with `key_ids` omitted or `[]` now blocks all keys (`allow_all_keys: false`). Add `"key_ids": ["*"]`. **Provider create/update errors about `keys` field** The `keys` field has been removed. Remove it from provider payloads and use `/api/providers/{provider}/keys` instead. -**Ollama/SGL requests fail with "no base URL configured"** +**Replicate requests failing after upgrade** + +If you used `replicate_key_config.deployments`, move the mappings to the top-level `aliases` field and set `use_deployments_endpoint: true` if you were targeting the Deployments API. + +**Go SDK compilation errors on `ModelRequested` or `StreamAccumulatorResult.Model`** -Ollama/SGL keys must have `ollama_key_config.url` / `sgl_key_config.url` set. Database users are auto-migrated. For `config.json`, add the key config manually β€” see Breaking Change 9 for examples. +Rename to `OriginalModelRequested`/`ResolvedModelUsed` on ExtraFields, and `RequestedModel`/`ResolvedModel` on StreamAccumulatorResult. diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 3c517e02f3..1043039a1f 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -134543,61 +134543,133 @@ } } }, - "/api/session/login": { - "post": { - "operationId": "login", - "summary": "Login", - "description": "Authenticates a user and returns a session token.\nSets a cookie with the session token for subsequent requests.\n", + "/api/users": { + "get": { + "operationId": "listUsers", + "summary": "List users", + "description": "Returns a paginated list of users with optional search.", "tags": [ - "Session" + "Users" ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "type": "object", - "description": "Login request", - "required": [ - "username", - "password" - ], - "properties": { - "username": { - "type": "string" - }, - "password": { - "type": "string" - } - } - } + "parameters": [ + { + "name": "page", + "in": "query", + "description": "Page number (1-based)", + "schema": { + "type": "integer", + "minimum": 1, + "default": 1 + } + }, + { + "name": "limit", + "in": "query", + "description": "Number of users per page (max 100)", + "schema": { + "type": "integer", + "minimum": 1, + "maximum": 100, + "default": 20 + } + }, + { + "name": "search", + "in": "query", + "description": "Search by name or email", + "schema": { + "type": "string" } } - }, + ], "responses": { "200": { - "description": "Login successful", + "description": "Successful response", "content": { "application/json": { "schema": { "type": "object", - "description": "Login response", "properties": { - "message": { - "type": "string", - "example": "Login successful" + "users": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique user identifier" + }, + "name": { + "type": "string", + "description": "User's display name" + }, + "email": { + "type": "string", + "format": "email", + "description": "User's email address" + }, + "role_id": { + "type": "integer", + "nullable": true, + "description": "ID of the assigned RBAC role" + }, + "role": { + "type": "object", + "nullable": true, + "description": "RBAC role details", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "is_system_role": { + "type": "boolean" + } + } + }, + "created_at": { + "type": "string", + "format": "date-time" + }, + "updated_at": { + "type": "string", + "format": "date-time" + } + } + } }, - "token": { - "type": "string", - "description": "Session token" + "total": { + "type": "integer", + "description": "Total number of users matching the query" + }, + "page": { + "type": "integer", + "description": "Current page number" + }, + "limit": { + "type": "integer", + "description": "Number of users per page" + }, + "total_pages": { + "type": "integer", + "description": "Total number of pages" + }, + "has_more": { + "type": "boolean", + "description": "Whether more pages are available" } } } } } }, - "400": { - "description": "Bad request", + "500": { + "description": "Internal server error", "content": { "application/json": { "schema": { @@ -134680,9 +134752,110 @@ } } } + } + } + }, + "post": { + "operationId": "createUser", + "summary": "Create user", + "description": "Manually creates a new user in the organization.", + "tags": [ + "Users" + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "name", + "email" + ], + "properties": { + "name": { + "type": "string", + "description": "User's display name" + }, + "email": { + "type": "string", + "format": "email", + "description": "User's email address (must be unique)" + }, + "role_id": { + "type": "integer", + "description": "Optional RBAC role ID to assign" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "User created successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique user identifier" + }, + "name": { + "type": "string", + "description": "User's display name" + }, + "email": { + "type": "string", + "format": "email", + "description": "User's email address" + }, + "role_id": { + "type": "integer", + "nullable": true, + "description": "ID of the assigned RBAC role" + }, + "role": { + "type": "object", + "nullable": true, + "description": "RBAC role details", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "is_system_role": { + "type": "boolean" + } + } + }, + "created_at": { + "type": "string", + "format": "date-time" + }, + "updated_at": { + "type": "string", + "format": "date-time" + } + } + } + } + } + } + } }, - "401": { - "description": "Invalid credentials", + "400": { + "description": "Bad request", "content": { "application/json": { "schema": { @@ -134766,13 +134939,13 @@ } } }, - "403": { - "description": "Authentication is not enabled", + "409": { + "description": "User with this email already exists", "content": { "application/json": { "schema": { "type": "object", - "description": "Error response from Bifrost", + "description": "Error response", "properties": { "event_id": { "type": "string" @@ -134939,39 +135112,49 @@ } } }, - "/api/session/logout": { - "post": { - "operationId": "logout", - "summary": "Logout", - "description": "Logs out the current user and invalidates the session token.", + "/api/users/{id}": { + "delete": { + "operationId": "deleteUser", + "summary": "Delete user", + "description": "Permanently removes a user from the organization. This cascades to delete the user's governance settings (budget/rate limits), team memberships, access profiles, and OIDC sessions. Cannot delete yourself.\n", "tags": [ - "Session" + "Users" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "User ID", + "schema": { + "type": "string" + } + } ], "responses": { "200": { - "description": "Logout successful", + "description": "User deleted successfully", "content": { "application/json": { "schema": { "type": "object", - "description": "Logout response", + "description": "Simple message response", "properties": { "message": { - "type": "string", - "example": "Logout successful" + "type": "string" } } } } } }, - "403": { - "description": "Authentication is not enabled", + "400": { + "description": "Bad request (e.g. cannot delete yourself)", "content": { "application/json": { "schema": { "type": "object", - "description": "Error response from Bifrost", + "description": "Error response", "properties": { "event_id": { "type": "string" @@ -135049,45 +135232,4020 @@ } } } - } - } - } - }, - "/api/session/is-auth-enabled": { - "get": { - "operationId": "isAuthEnabled", - "summary": "Check if authentication is enabled", - "description": "Returns whether authentication is enabled and if the current token is valid.", - "tags": [ - "Session" - ], - "responses": { - "200": { - "description": "Successful response", - "content": { - "application/json": { - "schema": { - "type": "object", - "description": "Auth enabled status response", - "properties": { - "is_auth_enabled": { - "type": "boolean" - }, - "has_valid_token": { - "type": "boolean" - } - } - } - } - } }, - "500": { - "description": "Internal server error", + "404": { + "description": "User not found", "content": { "application/json": { "schema": { "type": "object", - "description": "Error response from Bifrost", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/users/me/permissions": { + "get": { + "operationId": "getCurrentUserPermissions", + "summary": "Get current user permissions", + "description": "Returns the RBAC permissions for the authenticated user. When SCIM is not enabled, returns full permissions for all resources. Otherwise returns the permissions associated with the user's assigned role.\n", + "tags": [ + "Users" + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "permissions": { + "type": "object", + "description": "Map of resource names to their permitted operations. When SCIM is disabled, returns full permissions for all resources.\n", + "additionalProperties": { + "type": "object", + "additionalProperties": { + "type": "boolean" + } + } + } + } + } + } + } + }, + "401": { + "description": "Unauthorized (user not authenticated)", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "User not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/users/{id}/role": { + "put": { + "operationId": "assignUserRole", + "summary": "Assign role to user", + "description": "Assigns an RBAC role to a user. This also auto-assigns the default access profile for the new role and reloads the RBAC permission cache.\n", + "tags": [ + "Users" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "User ID", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "role_id" + ], + "properties": { + "role_id": { + "type": "integer", + "description": "ID of the RBAC role to assign" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Role assigned successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Simple message response", + "properties": { + "message": { + "type": "string" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "User or role not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/users/{id}/teams": { + "get": { + "operationId": "getUserTeams", + "summary": "Get user's teams", + "description": "Returns the list of teams a user belongs to, including the membership source.", + "tags": [ + "Users" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "User ID", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "teams": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Team ID" + }, + "name": { + "type": "string", + "description": "Team name" + }, + "source": { + "type": "string", + "description": "How the user was added to this team (e.g. \"manual\", \"scim_sync\")" + } + } + } + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "User not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + }, + "put": { + "operationId": "updateUserTeams", + "summary": "Update user's team assignments", + "description": "Replaces the user's manual team assignments. Synced team memberships (from SCIM providers) are preserved and cannot be removed via this endpoint.\n", + "tags": [ + "Users" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "User ID", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "team_ids" + ], + "properties": { + "team_ids": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of team IDs to assign (replaces existing manual assignments; synced memberships are preserved)" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Teams updated successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Simple message response", + "properties": { + "message": { + "type": "string" + } + } + } + } + } + }, + "400": { + "description": "Bad request (e.g. team not found)", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "User not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/teams": { + "get": { + "operationId": "listTeams", + "summary": "List teams", + "description": "Returns a paginated list of teams with optional search.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "page", + "in": "query", + "description": "Page number (1-based)", + "schema": { + "type": "integer", + "minimum": 1, + "default": 1 + } + }, + { + "name": "limit", + "in": "query", + "description": "Number of teams per page (max 100)", + "schema": { + "type": "integer", + "minimum": 1, + "maximum": 100, + "default": 20 + } + }, + { + "name": "search", + "in": "query", + "description": "Search by team name", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "teams": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Team ID (derived from name)" + }, + "name": { + "type": "string", + "description": "Team name" + }, + "member_count": { + "type": "integer", + "description": "Number of members in the team" + }, + "created_at": { + "type": "string", + "format": "date-time" + }, + "updated_at": { + "type": "string", + "format": "date-time" + } + } + } + }, + "total": { + "type": "integer" + }, + "page": { + "type": "integer" + }, + "limit": { + "type": "integer" + }, + "total_pages": { + "type": "integer", + "description": "Total number of pages" + }, + "has_more": { + "type": "boolean", + "description": "Whether more pages are available" + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + }, + "post": { + "operationId": "createTeam", + "summary": "Create team", + "description": "Creates a new team. The team ID is derived from the name.", + "tags": [ + "Teams" + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "name" + ], + "properties": { + "name": { + "type": "string", + "description": "Team name (must be unique)" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Team created successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "409": { + "description": "Team with this name already exists", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/teams/{id}": { + "get": { + "operationId": "getTeam", + "summary": "Get team", + "description": "Returns details of a specific team including member count.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Team ID (derived from name)" + }, + "name": { + "type": "string", + "description": "Team name" + }, + "member_count": { + "type": "integer", + "description": "Number of members in the team" + }, + "created_at": { + "type": "string", + "format": "date-time" + }, + "updated_at": { + "type": "string", + "format": "date-time" + } + } + } + } + } + }, + "404": { + "description": "Team not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + }, + "put": { + "operationId": "updateTeam", + "summary": "Update team", + "description": "Updates a team. Note that renaming teams is not allowed.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Updated team description" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Team updated successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + } + } + } + } + } + }, + "400": { + "description": "Bad request (e.g. renaming not allowed)", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "Team not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + }, + "delete": { + "operationId": "deleteTeam", + "summary": "Delete team", + "description": "Permanently removes a team.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Team deleted successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Simple message response", + "properties": { + "message": { + "type": "string" + } + } + } + } + } + }, + "404": { + "description": "Team not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/teams/{id}/members": { + "get": { + "operationId": "getTeamMembers", + "summary": "List team members", + "description": "Returns all members of a team with their user details and membership source.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "members": { + "type": "array", + "items": { + "type": "object", + "properties": { + "user_id": { + "type": "string" + }, + "user_name": { + "type": "string" + }, + "user_email": { + "type": "string" + }, + "source": { + "type": "string", + "description": "How the member was added (e.g. \"manual\", \"scim_sync\")" + } + } + } + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "404": { + "description": "Team not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + }, + "post": { + "operationId": "addTeamMember", + "summary": "Add team member", + "description": "Adds a user to a team. Both the team and user must exist.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": [ + "user_id" + ], + "properties": { + "user_id": { + "type": "string", + "description": "ID of the user to add to the team" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Member added successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Simple message response", + "properties": { + "message": { + "type": "string" + } + } + } + } + } + }, + "404": { + "description": "Team or user not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "409": { + "description": "User is already a member of this team", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/teams/{id}/members/{userId}": { + "delete": { + "operationId": "removeTeamMember", + "summary": "Remove team member", + "description": "Removes a user from a team.", + "tags": [ + "Teams" + ], + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "Team ID", + "schema": { + "type": "string" + } + }, + { + "name": "userId", + "in": "path", + "required": true, + "description": "User ID to remove", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Member removed successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Simple message response", + "properties": { + "message": { + "type": "string" + } + } + } + } + } + }, + "404": { + "description": "Membership not found", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/session/login": { + "post": { + "operationId": "login", + "summary": "Login", + "description": "Authenticates a user and returns a session token.\nSets a cookie with the session token for subsequent requests.\n", + "tags": [ + "Session" + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Login request", + "required": [ + "username", + "password" + ], + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Login successful", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Login response", + "properties": { + "message": { + "type": "string", + "example": "Login successful" + }, + "token": { + "type": "string", + "description": "Session token" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "401": { + "description": "Invalid credentials", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "403": { + "description": "Authentication is not enabled", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/session/logout": { + "post": { + "operationId": "logout", + "summary": "Logout", + "description": "Logs out the current user and invalidates the session token.", + "tags": [ + "Session" + ], + "responses": { + "200": { + "description": "Logout successful", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Logout response", + "properties": { + "message": { + "type": "string", + "example": "Logout successful" + } + } + } + } + } + }, + "403": { + "description": "Authentication is not enabled", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", + "properties": { + "event_id": { + "type": "string" + }, + "type": { + "type": "string" + }, + "is_bifrost_error": { + "type": "boolean" + }, + "status_code": { + "type": "integer" + }, + "error": { + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "code": { + "type": "string" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + }, + "event_id": { + "type": "string" + } + } + }, + "extra_fields": { + "type": "object", + "properties": { + "provider": { + "type": "string", + "description": "AI model provider identifier", + "enum": [ + "openai", + "azure", + "anthropic", + "bedrock", + "cohere", + "vertex", + "vllm", + "mistral", + "ollama", + "groq", + "sgl", + "parasail", + "perplexity", + "replicate", + "cerebras", + "gemini", + "openrouter", + "elevenlabs", + "huggingface", + "nebius", + "xai", + "runway", + "fireworks" + ] + }, + "model_requested": { + "type": "string" + }, + "request_type": { + "type": "string" + } + } + } + } + } + } + } + } + } + } + }, + "/api/session/is-auth-enabled": { + "get": { + "operationId": "isAuthEnabled", + "summary": "Check if authentication is enabled", + "description": "Returns whether authentication is enabled and if the current token is valid.", + "tags": [ + "Session" + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Auth enabled status response", + "properties": { + "is_auth_enabled": { + "type": "boolean" + }, + "has_valid_token": { + "type": "boolean" + } + } + } + } + } + }, + "500": { + "description": "Internal server error", + "content": { + "application/json": { + "schema": { + "type": "object", + "description": "Error response from Bifrost", "properties": { "event_id": { "type": "string" @@ -138460,6 +142618,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -138479,12 +142648,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -138617,12 +142780,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -138705,12 +142862,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -138735,18 +142886,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -138822,6 +142961,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -139177,6 +143326,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -139196,12 +143356,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -139334,12 +143488,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -139422,12 +143570,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -139452,18 +143594,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -139539,6 +143669,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -139614,6 +143754,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -139633,12 +143784,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -139771,12 +143916,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -139859,12 +143998,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -139889,18 +144022,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -139976,6 +144097,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -140422,6 +144553,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -140441,12 +144583,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -140579,12 +144715,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -140667,12 +144797,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -140697,18 +144821,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -140784,6 +144896,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -141142,6 +145264,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -141161,12 +145294,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -141299,12 +145426,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -141387,12 +145508,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -141417,18 +145532,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -141504,6 +145607,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -141579,6 +145692,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -141598,12 +145722,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -141736,12 +145854,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -141824,12 +145936,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -141854,18 +145960,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -141941,6 +146035,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -142300,6 +146404,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -142319,12 +146434,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -142457,12 +146566,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -142545,12 +146648,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -142575,18 +146672,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -142662,6 +146747,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -143005,6 +147100,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -143024,12 +147130,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -143162,12 +147262,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -143250,12 +147344,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -143280,18 +147368,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -143367,6 +147443,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -146867,7 +150953,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -147115,7 +151202,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -147244,7 +151332,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -147373,7 +151462,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -147773,7 +151863,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -176306,6 +180397,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -176325,12 +180427,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -176463,12 +180559,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -176551,12 +180641,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -176581,18 +180665,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -176668,6 +180740,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -203143,6 +207225,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -203162,12 +207255,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -203300,12 +207387,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -203388,12 +207469,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -203418,18 +207493,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -203505,6 +207568,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -203578,6 +207651,17 @@ "type": "number", "description": "Weight for load balancing" }, + "aliases": { + "type": "object", + "propertyNames": { + "minLength": 1 + }, + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "description": "Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" + }, "azure_key_config": { "type": "object", "description": "Azure-specific key configuration", @@ -203597,12 +207681,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "api_version": { "type": "object", "description": "Environment variable configuration", @@ -203735,12 +207813,6 @@ "type": "boolean" } } - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } } } }, @@ -203823,12 +207895,6 @@ } } }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - }, "batch_s3_config": { "type": "object", "properties": { @@ -203853,18 +207919,6 @@ } } }, - "replicate_key_config": { - "type": "object", - "description": "Replicate-specific key configuration", - "properties": { - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - } - } - } - }, "vllm_key_config": { "type": "object", "description": "VLLM-specific key configuration", @@ -203940,6 +207994,16 @@ "url" ] }, + "replicate_key_config": { + "type": "object", + "description": "Replicate-specific key configuration", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint" + } + } + }, "enabled": { "type": "boolean", "description": "Whether the key is active (defaults to true)" @@ -204376,7 +208440,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, @@ -204537,7 +208602,8 @@ "enum": [ "none", "headers", - "oauth" + "oauth", + "per_user_oauth" ], "description": "Authentication type for the MCP connection" }, diff --git a/docs/openapi/openapi.yaml b/docs/openapi/openapi.yaml index cc125b63a7..0e17c36c92 100644 --- a/docs/openapi/openapi.yaml +++ b/docs/openapi/openapi.yaml @@ -548,6 +548,28 @@ paths: /api/pricing/force-sync: $ref: './paths/management/config.yaml#/force-sync-pricing' + # Users + /api/users: + $ref: './paths/management/users.yaml#/users' + /api/users/{id}: + $ref: './paths/management/users.yaml#/users-by-id' + /api/users/me/permissions: + $ref: './paths/management/users.yaml#/users-me-permissions' + /api/users/{id}/role: + $ref: './paths/management/users.yaml#/users-role' + /api/users/{id}/teams: + $ref: './paths/management/users.yaml#/users-teams' + + # Teams + /api/teams: + $ref: './paths/management/users.yaml#/teams' + /api/teams/{id}: + $ref: './paths/management/users.yaml#/teams-by-id' + /api/teams/{id}/members: + $ref: './paths/management/users.yaml#/team-members' + /api/teams/{id}/members/{userId}: + $ref: './paths/management/users.yaml#/team-member-by-id' + # Session /api/session/login: $ref: './paths/management/session.yaml#/login' diff --git a/docs/openapi/paths/management/users.yaml b/docs/openapi/paths/management/users.yaml new file mode 100644 index 0000000000..5de0e36fdd --- /dev/null +++ b/docs/openapi/paths/management/users.yaml @@ -0,0 +1,534 @@ +users: + get: + operationId: listUsers + summary: List users + description: Returns a paginated list of users with optional search. + tags: + - Users + parameters: + - name: page + in: query + description: Page number (1-based) + schema: + type: integer + minimum: 1 + default: 1 + - name: limit + in: query + description: Number of users per page (max 100) + schema: + type: integer + minimum: 1 + maximum: 100 + default: 20 + - name: search + in: query + description: Search by name or email + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/ListUsersResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + post: + operationId: createUser + summary: Create user + description: Manually creates a new user in the organization. + tags: + - Users + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/CreateUserRequest' + responses: + '200': + description: User created successfully + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/UserResponse' + '400': + $ref: '../../openapi.yaml#/components/responses/BadRequest' + '409': + description: User with this email already exists + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +users-by-id: + delete: + operationId: deleteUser + summary: Delete user + description: > + Permanently removes a user from the organization. This cascades to delete + the user's governance settings (budget/rate limits), team memberships, + access profiles, and OIDC sessions. Cannot delete yourself. + tags: + - Users + parameters: + - name: id + in: path + required: true + description: User ID + schema: + type: string + responses: + '200': + description: User deleted successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '400': + description: Bad request (e.g. cannot delete yourself) + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '404': + description: User not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +users-me-permissions: + get: + operationId: getCurrentUserPermissions + summary: Get current user permissions + description: > + Returns the RBAC permissions for the authenticated user. When SCIM is not + enabled, returns full permissions for all resources. Otherwise returns the + permissions associated with the user's assigned role. + tags: + - Users + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/PermissionsResponse' + '401': + description: Unauthorized (user not authenticated) + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '404': + description: User not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +users-role: + put: + operationId: assignUserRole + summary: Assign role to user + description: > + Assigns an RBAC role to a user. This also auto-assigns the default + access profile for the new role and reloads the RBAC permission cache. + tags: + - Users + parameters: + - name: id + in: path + required: true + description: User ID + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/AssignUserRoleRequest' + responses: + '200': + description: Role assigned successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '400': + $ref: '../../openapi.yaml#/components/responses/BadRequest' + '404': + description: User or role not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +users-teams: + get: + operationId: getUserTeams + summary: Get user's teams + description: Returns the list of teams a user belongs to, including the membership source. + tags: + - Users + parameters: + - name: id + in: path + required: true + description: User ID + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/UserTeamsResponse' + '400': + $ref: '../../openapi.yaml#/components/responses/BadRequest' + '404': + description: User not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + put: + operationId: updateUserTeams + summary: Update user's team assignments + description: > + Replaces the user's manual team assignments. Synced team memberships + (from SCIM providers) are preserved and cannot be removed via this endpoint. + tags: + - Users + parameters: + - name: id + in: path + required: true + description: User ID + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/UpdateUserTeamsRequest' + responses: + '200': + description: Teams updated successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '400': + description: Bad request (e.g. team not found) + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '404': + description: User not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +# ---- Teams ---- + +teams: + get: + operationId: listTeams + summary: List teams + description: Returns a paginated list of teams with optional search. + tags: + - Teams + parameters: + - name: page + in: query + description: Page number (1-based) + schema: + type: integer + minimum: 1 + default: 1 + - name: limit + in: query + description: Number of teams per page (max 100) + schema: + type: integer + minimum: 1 + maximum: 100 + default: 20 + - name: search + in: query + description: Search by team name + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/ListTeamsResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + post: + operationId: createTeam + summary: Create team + description: Creates a new team. The team ID is derived from the name. + tags: + - Teams + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/CreateTeamRequest' + responses: + '200': + description: Team created successfully + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/CreateTeamResponse' + '400': + $ref: '../../openapi.yaml#/components/responses/BadRequest' + '409': + description: Team with this name already exists + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +teams-by-id: + get: + operationId: getTeam + summary: Get team + description: Returns details of a specific team including member count. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/TeamObject' + '404': + description: Team not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + put: + operationId: updateTeam + summary: Update team + description: Updates a team. Note that renaming teams is not allowed. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/UpdateTeamRequest' + responses: + '200': + description: Team updated successfully + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/CreateTeamResponse' + '400': + description: Bad request (e.g. renaming not allowed) + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '404': + description: Team not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + delete: + operationId: deleteTeam + summary: Delete team + description: Permanently removes a team. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + responses: + '200': + description: Team deleted successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '404': + description: Team not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +# ---- Team Members ---- + +team-members: + get: + operationId: getTeamMembers + summary: List team members + description: Returns all members of a team with their user details and membership source. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/TeamMembersResponse' + '400': + $ref: '../../openapi.yaml#/components/responses/BadRequest' + '404': + description: Team not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + + post: + operationId: addTeamMember + summary: Add team member + description: Adds a user to a team. Both the team and user must exist. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '../../schemas/management/users.yaml#/AddTeamMemberRequest' + responses: + '200': + description: Member added successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '404': + description: Team or user not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '409': + description: User is already a member of this team + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' + +team-member-by-id: + delete: + operationId: removeTeamMember + summary: Remove team member + description: Removes a user from a team. + tags: + - Teams + parameters: + - name: id + in: path + required: true + description: Team ID + schema: + type: string + - name: userId + in: path + required: true + description: User ID to remove + schema: + type: string + responses: + '200': + description: Member removed successfully + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/MessageResponse' + '404': + description: Membership not found + content: + application/json: + schema: + $ref: '../../schemas/management/common.yaml#/ErrorResponse' + '500': + $ref: '../../openapi.yaml#/components/responses/InternalError' diff --git a/docs/openapi/schemas/management/mcp.yaml b/docs/openapi/schemas/management/mcp.yaml index 893885ca36..ab514620b6 100644 --- a/docs/openapi/schemas/management/mcp.yaml +++ b/docs/openapi/schemas/management/mcp.yaml @@ -2,12 +2,13 @@ MCPAuthType: type: string - enum: [none, headers, oauth] + enum: [none, headers, oauth, per_user_oauth] description: | Authentication type for MCP connections: - none: No authentication - headers: Header-based authentication (API keys, custom headers, etc.) - - oauth: OAuth 2.0 authentication + - oauth: OAuth 2.0 authentication (server-level, admin authenticates once) + - per_user_oauth: Per-user OAuth 2.0 authentication (each user authenticates individually) MCPConnectionType: type: string diff --git a/docs/openapi/schemas/management/providers.yaml b/docs/openapi/schemas/management/providers.yaml index b5e3a07b41..435b6f7feb 100644 --- a/docs/openapi/schemas/management/providers.yaml +++ b/docs/openapi/schemas/management/providers.yaml @@ -71,10 +71,6 @@ AzureKeyConfig: properties: endpoint: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string api_version: $ref: '../../schemas/management/common.yaml#/EnvVar' client_id: @@ -101,10 +97,6 @@ VertexKeyConfig: $ref: '../../schemas/management/common.yaml#/EnvVar' auth_credentials: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string BedrockKeyConfig: type: object @@ -120,10 +112,6 @@ BedrockKeyConfig: $ref: '../../schemas/management/common.yaml#/EnvVar' arn: $ref: '../../schemas/management/common.yaml#/EnvVar' - deployments: - type: object - additionalProperties: - type: string batch_s3_config: type: object properties: @@ -159,6 +147,14 @@ OllamaKeyConfig: required: - url +ReplicateKeyConfig: + type: object + description: Replicate-specific key configuration + properties: + use_deployments_endpoint: + type: boolean + description: Whether to use the deployments endpoint instead of the models endpoint + SglKeyConfig: type: object description: SGLang-specific key configuration @@ -168,14 +164,16 @@ SglKeyConfig: required: - url -ReplicateKeyConfig: +VLLMKeyConfig: type: object - description: Replicate-specific key configuration + description: vLLM-specific key configuration for per-key routing to different vLLM instances properties: - deployments: - type: object - additionalProperties: - type: string + url: + $ref: '../../schemas/management/common.yaml#/EnvVar' + description: vLLM server base URL (required) + model_name: + type: string + description: Exact model name served on this vLLM instance VLLMKeyConfig: type: object @@ -214,20 +212,28 @@ Key: weight: type: number description: Weight for load balancing + aliases: + type: object + propertyNames: + minLength: 1 + additionalProperties: + type: string + minLength: 1 + description: Model alias mappings β€” maps a user-facing model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.) azure_key_config: $ref: '#/AzureKeyConfig' vertex_key_config: $ref: '#/VertexKeyConfig' bedrock_key_config: $ref: '#/BedrockKeyConfig' - replicate_key_config: - $ref: '#/ReplicateKeyConfig' vllm_key_config: $ref: '#/VllmKeyConfig' ollama_key_config: $ref: '#/OllamaKeyConfig' sgl_key_config: $ref: '#/SglKeyConfig' + replicate_key_config: + $ref: '#/ReplicateKeyConfig' enabled: type: boolean description: Whether the key is active (defaults to true) diff --git a/docs/openapi/schemas/management/users.yaml b/docs/openapi/schemas/management/users.yaml new file mode 100644 index 0000000000..46db148f3e --- /dev/null +++ b/docs/openapi/schemas/management/users.yaml @@ -0,0 +1,239 @@ +UserObject: + type: object + properties: + id: + type: string + description: Unique user identifier + name: + type: string + description: User's display name + email: + type: string + format: email + description: User's email address + role_id: + type: integer + nullable: true + description: ID of the assigned RBAC role + role: + type: object + nullable: true + description: RBAC role details + properties: + id: + type: integer + name: + type: string + description: + type: string + is_system_role: + type: boolean + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + +CreateUserRequest: + type: object + required: + - name + - email + properties: + name: + type: string + description: User's display name + email: + type: string + format: email + description: User's email address (must be unique) + role_id: + type: integer + description: Optional RBAC role ID to assign + +UserResponse: + type: object + properties: + user: + $ref: '#/UserObject' + +ListUsersResponse: + type: object + properties: + users: + type: array + items: + $ref: '#/UserObject' + total: + type: integer + description: Total number of users matching the query + page: + type: integer + description: Current page number + limit: + type: integer + description: Number of users per page + total_pages: + type: integer + description: Total number of pages + has_more: + type: boolean + description: Whether more pages are available + +# ---- User Permissions ---- + +PermissionsResponse: + type: object + properties: + permissions: + type: object + description: > + Map of resource names to their permitted operations. + When SCIM is disabled, returns full permissions for all resources. + additionalProperties: + type: object + additionalProperties: + type: boolean + +# ---- User Role ---- + +AssignUserRoleRequest: + type: object + required: + - role_id + properties: + role_id: + type: integer + description: ID of the RBAC role to assign + +# ---- User Teams ---- + +UserTeamEntry: + type: object + properties: + id: + type: string + description: Team ID + name: + type: string + description: Team name + source: + type: string + description: How the user was added to this team (e.g. "manual", "scim_sync") + +UserTeamsResponse: + type: object + properties: + teams: + type: array + items: + $ref: '#/UserTeamEntry' + +UpdateUserTeamsRequest: + type: object + required: + - team_ids + properties: + team_ids: + type: array + items: + type: string + description: List of team IDs to assign (replaces existing manual assignments; synced memberships are preserved) + +# ---- Teams ---- + +TeamObject: + type: object + properties: + id: + type: string + description: Team ID (derived from name) + name: + type: string + description: Team name + member_count: + type: integer + description: Number of members in the team + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + +CreateTeamRequest: + type: object + required: + - name + properties: + name: + type: string + description: Team name (must be unique) + +UpdateTeamRequest: + type: object + properties: + description: + type: string + description: Updated team description + +CreateTeamResponse: + type: object + properties: + id: + type: string + name: + type: string + +ListTeamsResponse: + type: object + properties: + teams: + type: array + items: + $ref: '#/TeamObject' + total: + type: integer + page: + type: integer + limit: + type: integer + total_pages: + type: integer + description: Total number of pages + has_more: + type: boolean + description: Whether more pages are available + +# ---- Team Members ---- + +TeamMemberObject: + type: object + properties: + user_id: + type: string + user_name: + type: string + user_email: + type: string + source: + type: string + description: How the member was added (e.g. "manual", "scim_sync") + +TeamMembersResponse: + type: object + properties: + members: + type: array + items: + $ref: '#/TeamMemberObject' + +AddTeamMemberRequest: + type: object + required: + - user_id + properties: + user_id: + type: string + description: ID of the user to add to the team diff --git a/docs/providers/aliasing-models.mdx b/docs/providers/aliasing-models.mdx new file mode 100644 index 0000000000..caa17f2526 --- /dev/null +++ b/docs/providers/aliasing-models.mdx @@ -0,0 +1,345 @@ +--- +title: "Aliasing Models" +description: "Map arbitrary model names to any target identifier using static key-level aliases or dynamic routing rules." +icon: "tag" +--- + +## Overview + +Model aliasing lets you decouple the model name your application sends from the identifier Bifrost actually uses when calling a provider. You can: + +- Send `"best-model"` and have Bifrost resolve it to whatever model you've decided is best β€” without touching your application code +- Map a single logical name like `"gpt-4o"` to a provider-specific deployment name, inference profile ARN, or fine-tuned model ID +- Give different teams different underlying models behind the same name + +There are two aliasing mechanisms, and they operate at different layers: + +| | Static Aliases | Dynamic Aliases (Routing Rules) | +|---|---|---| +| **Where configured** | On a provider key | On routing rules, scoped to VK / Team / Customer / Global | +| **When applied** | After key selection, before the provider API call | At request time, before key selection | +| **Scope** | Per-key | Per-VK, per-team, per-customer, or global | +| **Condition-based** | No β€” always resolves | Yes β€” CEL expression controls when it fires | + +--- + +## Static Aliasing + +Static aliasing is available in **Bifrost v1.5.0-prerelease2 and above**. + +Static aliases are configured directly on a provider key. Every request that is served by that key will have its model name resolved through the alias map before the request reaches the provider API. + +### How it works + +1. Your application sends a request with `model: "best-model"` +2. Bifrost selects a key that supports `"best-model"` (alias names are treated as model identifiers for key selection and allowlists) +3. Before calling the provider, Bifrost resolves `"best-model"` β†’ `"gpt-4o-2024-11-20"` using that key's `aliases` map +4. The provider receives `"gpt-4o-2024-11-20"` β€” your application never needs to know + +### Configuration + +Add an `aliases` object to any key in `config.json`: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "aliases": { + "best-model": "gpt-4o-2024-11-20", + "fast-model": "gpt-4o-mini", + "embedder": "text-embedding-3-large" + } + } + ] + } + } +} +``` + +You can also add aliases via the provider keys API: + +```bash +curl -X POST http://localhost:8080/api/providers/openai/keys \ + -H "Content-Type: application/json" \ + -d '{ + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "aliases": { + "best-model": "gpt-4o-2024-11-20", + "fast-model": "gpt-4o-mini" + } + }' +``` + +The `aliases` field is a flat `string β†’ string` map. The key is what your application sends; the value is what gets forwarded to the provider. There are no restrictions on what either side can be β€” deployments, ARNs, model IDs, version hashes, fine-tune IDs, anything. + +### Validation rules + +Bifrost rejects an aliases map that violates any of these: + +- **No empty strings** β€” both the alias name and its target must be non-empty +- **No leading or trailing whitespace** on either side +- **No duplicate alias names** (checked case-insensitively) β€” `"GPT-4o"` and `"gpt-4o"` cannot both be keys in the same map + +### Case-insensitive matching + +Alias lookup is case-insensitive. If your map has `"GPT-4O": "gpt-4o-2024-11-20"` and a request comes in with `model: "gpt-4o"`, it resolves correctly. Aliases are stored as-is but matched without regard to case. + +### Tracking in responses + +Every response includes both the original name and the resolved identifier in `extra_fields`: + +```json +{ + "extra_fields": { + "original_model_requested": "best-model", + "resolved_model_used": "gpt-4o-2024-11-20", + "provider": "openai" + } +} +``` + +If no alias matches, `resolved_model_used` equals `original_model_requested`. + +--- + +## Dynamic Aliasing + +Dynamic aliasing uses [Routing Rules](/providers/routing-rules) to rewrite the model at request time based on a CEL expression. Unlike static aliases (which are fixed to a key), dynamic aliases fire conditionally and are scoped β€” so the same model name can resolve differently depending on who is making the request. + +### How scopes make it dynamic + +Routing rules are organized into four scopes, evaluated in priority order: + +``` +Virtual Key scope β†’ Team scope β†’ Customer scope β†’ Global scope +``` + +This means you can configure aliasing at any level of your org hierarchy. For example: + +- **Global scope** aliases `"best-model"` β†’ `"gpt-4o-mini"` (cost-effective default for everyone) +- **Team scope** for the AI team overrides `"best-model"` β†’ `"claude-3-5-sonnet-20241022"` (more capable) +- **Virtual Key scope** for a specific VK overrides `"best-model"` β†’ `"o1"` (highest capability, specific use case) + +Each requester gets the right model behind the same name, with zero changes to the application. + +### Example: alias based on request type + +```json +{ + "name": "route-embeddings-to-fast-model", + "cel_expression": "request_type == 'embedding' && model == 'embedder'", + "targets": [ + { "model": "text-embedding-3-small", "weight": 1.0 } + ], + "scope": "global" +} +``` + +Any request with `model: "embedder"` that is an embedding request gets routed to `"text-embedding-3-small"`. + +### Example: alias with provider switch + +```json +{ + "name": "premium-tier-routing", + "cel_expression": "headers['x-tier'] == 'premium'", + "targets": [ + { "provider": "anthropic", "model": "claude-3-5-sonnet-20241022", "weight": 1.0 } + ], + "scope": "global" +} +``` + +Premium-tier requests get routed to Anthropic's Sonnet regardless of what model the client sent. + +### Multi-step rewrites with chaining + +Setting `chain_rule: true` on a rule causes Bifrost to re-evaluate the full scope chain with the new provider/model as the new context. This lets you build layered alias resolution where a global rule establishes provider intent and a VK-scoped rule applies the final key selection. + +**Scenario:** All clients send `model: "best-model"`. Premium VKs should get `gpt-5` via a high-tier key; standard VKs should get `gpt-4.1` via a lower-tier key. + +**Rule 1 β€” Global scope (`chain_rule: true`):** +```json +{ + "name": "resolve-best-model-provider", + "cel_expression": "model == 'best-model'", + "targets": [ + { "provider": "openai", "model": "best-model", "weight": 1.0 } + ], + "scope": "global", + "chain_rule": true +} +``` + +This establishes that `best-model` resolves to OpenAI and re-evaluates the scope chain with `provider="openai", model="best-model"`. + +**Rule 2a β€” VK scope on `premium-vk` (`chain_rule: false`):** +```json +{ + "name": "premium-model-selection", + "cel_expression": "provider == 'openai' && model == 'best-model'", + "targets": [ + { "provider": "openai", "model": "gpt-5", "weight": 1.0 } + ], + "scope": "virtual_key", + "scope_id": "premium-vk" +} +``` + +**Rule 2b β€” VK scope on `standard-vk` (`chain_rule: false`):** +```json +{ + "name": "standard-model-selection", + "cel_expression": "provider == 'openai' && model == 'best-model'", + "targets": [ + { "provider": "openai", "model": "gpt-4.1", "weight": 1.0 } + ], + "scope": "virtual_key", + "scope_id": "standard-vk" +} +``` + +**What happens for a `premium-vk` request:** +``` +model="best-model" via premium-vk + ↓ Rule 1 (global, chain_rule: true) +provider="openai", model="best-model" β€” re-evaluate scope chain + ↓ Rule 2a (premium-vk scope, chain_rule: false) +provider="openai", model="gpt-5" β€” done +OpenAI receives model="gpt-5" +``` + +**What happens for a `standard-vk` request:** +``` +model="best-model" via standard-vk + ↓ Rule 1 (global, chain_rule: true) +provider="openai", model="best-model" β€” re-evaluate scope chain + ↓ Rule 2b (standard-vk scope, chain_rule: false) +provider="openai", model="gpt-4.1" β€” done +OpenAI receives model="gpt-4.1" +``` + +Each step in the chain can change provider, model, or both. Cycle detection prevents infinite loops. + +See the [Routing Rules](/providers/routing-rules) documentation for the full CEL expression reference, priority configuration, and chaining details. + +--- + +## Advanced: Combining Both Layers + +Static and dynamic aliasing compose naturally β€” routing rules fire first (at the HTTP layer), then key-level aliases resolve second (inside the inference worker, after key selection). This lets you separate concerns across two distinct layers: + +- **Routing rules** decide *which provider* and *which key tier* to use, based on who is making the request +- **Key aliases** handle *the final model identifier* forwarded to the provider + +### Example + +**Setup:** Two OpenAI keys with different tiers, each with their own `best-model` alias: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "id": "high-tier-key", + "value": "env.OPENAI_HIGH_TIER_KEY", + "models": ["*"], + "aliases": { "best-model": "gpt-5" } + }, + { + "id": "low-tier-key", + "value": "env.OPENAI_LOW_TIER_KEY", + "models": ["*"], + "aliases": { "best-model": "gpt-4o" } + } + ] + }, + "anthropic": { + "keys": [ + { + "id": "anthropic-key", + "value": "env.ANTHROPIC_KEY", + "models": ["*"], + "aliases": { "best-model": "claude-3-5-sonnet-20241022" } + } + ] + } + } +} +``` + +**Routing rules:** Two team-scoped rules handle provider selection, and two VK-scoped rules handle key tier selection. + +```json +[ + { + "name": "tech-team-provider", + "cel_expression": "model == 'best-model'", + "targets": [{ "provider": "openai", "model": "best-model", "weight": 1.0 }], + "scope": "team", + "scope_id": "tech-team", + "chain_rule": true + }, + { + "name": "ml-team-provider", + "cel_expression": "model == 'best-model'", + "targets": [{ "provider": "anthropic", "model": "best-model", "weight": 1.0 }], + "scope": "team", + "scope_id": "ml-team", + "chain_rule": true + }, + { + "name": "premium-vk-key-selection", + "cel_expression": "provider == 'openai' && model == 'best-model'", + "targets": [{ "provider": "openai", "model": "best-model", "key_id": "high-tier-key", "weight": 1.0 }], + "scope": "virtual_key", + "scope_id": "premium-vk" + }, + { + "name": "standard-vk-key-selection", + "cel_expression": "provider == 'openai' && model == 'best-model'", + "targets": [{ "provider": "openai", "model": "best-model", "key_id": "low-tier-key", "weight": 1.0 }], + "scope": "virtual_key", + "scope_id": "standard-vk" + } +] +``` + +**Resolution paths:** + +``` +tech-team + premium-vk β†’ model="best-model" + ↓ Team rule: provider="openai", model="best-model" (chain) + ↓ VK rule: key=high-tier-key + ↓ Alias: "best-model" β†’ "gpt-5" + β†’ OpenAI receives model="gpt-5" + +tech-team + standard-vk β†’ model="best-model" + ↓ Team rule: provider="openai", model="best-model" (chain) + ↓ VK rule: key=low-tier-key + ↓ Alias: "best-model" β†’ "gpt-4o" + β†’ OpenAI receives model="gpt-4o" + +ml-team β†’ model="best-model" + ↓ Team rule: provider="anthropic", model="best-model" (chain) + ↓ No VK rule matches anthropic β€” chain terminates + ↓ Alias: "best-model" β†’ "claude-3-5-sonnet-20241022" + β†’ Anthropic receives model="claude-3-5-sonnet-20241022" +``` + +**Response `extra_fields` for tech-team + premium-vk:** +```json +{ + "original_model_requested": "best-model", + "resolved_model_used": "gpt-5", + "provider": "openai" +} +``` + +`original_model_requested` is always what the client originally sent. `resolved_model_used` is the final identifier that reached the provider API β€” after both routing and alias resolution. diff --git a/docs/providers/provider-routing.mdx b/docs/providers/provider-routing.mdx index dc214adb76..7a341a8e40 100644 --- a/docs/providers/provider-routing.mdx +++ b/docs/providers/provider-routing.mdx @@ -427,19 +427,19 @@ This is particularly useful for proxy providers (OpenRouter, Vertex) where you w - + -**Key Concept**: Deployments are **key-specific** mappings that allow user-friendly model names to map to provider-specific deployment identifiers. +**Key Concept**: Aliases are **key-level** mappings that allow user-friendly model names to map to provider-specific identifiers. -**How Deployments Work**: +**How Aliases Work**: - Defined at the **Key level**, not Virtual Key level -- Structure: `deployments: {"alias": "deployment-id"}` -- **Alias** (left side): User-facing model name used in requests -- **Deployment ID** (right side): Provider-specific identifier sent to the API +- Structure: `aliases: {"user-facing-name": "provider-specific-id"}` +- **Alias key** (left side): User-facing model name used in requests +- **Provider ID** (right side): Provider-specific identifier sent to the API **Azure OpenAI Example**: -Provider configuration with deployment mapping: +Provider configuration with alias mapping: ```json { "providers": { @@ -448,13 +448,12 @@ Provider configuration with deployment mapping: { "name": "azure-prod-key", "value": "your-api-key", - "models": [], // Not used when deployments exist + "aliases": { + "gpt-4o": "my-prod-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, "azure_key_config": { - "endpoint": "https://your-resource.openai.azure.com", - "deployments": { - "gpt-4o": "my-prod-gpt4o-deployment", - "gpt-4o-mini": "my-mini-deployment" - } + "endpoint": "https://your-resource.openai.azure.com" } } ] @@ -467,9 +466,9 @@ Provider configuration with deployment mapping: 1. **Allowed models derived from aliases**: `["gpt-4o", "gpt-4o-mini"]` 2. **User requests with alias**: `{"model": "gpt-4o"}` 3. **Bifrost validates**: `gpt-4o` is in derived allowed models βœ… -4. **Bifrost maps to deployment**: `gpt-4o` β†’ `my-prod-gpt4o-deployment` +4. **Bifrost resolves alias**: `gpt-4o` β†’ `my-prod-gpt4o-deployment` 5. **Sent to Azure**: Uses `my-prod-gpt4o-deployment` as the deployment name -6. **Pricing lookup**: If pricing for deployment not found, falls back to alias `gpt-4o` +6. **Pricing lookup**: If pricing for resolved ID not found, falls back to alias `gpt-4o` **Bedrock Example with Inference Profiles**: @@ -480,15 +479,14 @@ Provider configuration with deployment mapping: "keys": [ { "name": "bedrock-key", - "models": [], + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-opus": "us.anthropic.claude-3-opus-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "your-access-key", "secret_key": "your-secret-key", - "region": "us-east-1", - "deployments": { - "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", - "claude-opus": "us.anthropic.claude-3-opus-20240229-v1:0" - } + "region": "us-east-1" } } ] @@ -498,10 +496,10 @@ Provider configuration with deployment mapping: ``` **What Happens**: -1. **Allowed models**: `["claude-sonnet", "claude-opus"]` (from deployment aliases) +1. **Allowed models**: `["claude-sonnet", "claude-opus"]` (from alias keys) 2. **User requests**: `{"model": "claude-sonnet"}` 3. **Bifrost validates**: `claude-sonnet` in allowed models βœ… -4. **Maps to inference profile**: `claude-sonnet` β†’ `us.anthropic.claude-3-5-sonnet-20241022-v2:0` +4. **Resolves alias**: `claude-sonnet` β†’ `us.anthropic.claude-3-5-sonnet-20241022-v2:0` 5. **Sent to Bedrock**: Full ARN used in API call **Priority of Model Restrictions**: @@ -509,7 +507,7 @@ Provider configuration with deployment mapping: When determining allowed models for a key: ``` 1. If key.models is NOT empty β†’ Use key.models -2. Else if deployments exist β†’ Use deployment aliases (map keys) +2. Else if aliases exist β†’ Use alias keys 3. Else β†’ All models allowed (use Model Catalog) ``` @@ -519,11 +517,12 @@ When determining allowed models for a key: "keys": [ { "models": ["gpt-4o", "gpt-3.5-turbo"], // Explicit restriction + "aliases": { + "gpt-4o": "my-deployment", + "gpt-4-turbo": "another-deployment" // NOT accessible! + }, "azure_key_config": { - "deployments": { - "gpt-4o": "my-deployment", - "gpt-4-turbo": "another-deployment" // NOT accessible! - } + "endpoint": "https://your-resource.openai.azure.com" } } ] @@ -536,39 +535,39 @@ Result: Only `["gpt-4o", "gpt-3.5-turbo"]` allowed (models field takes priority) { "keys": [ { + "aliases": { + "claude-3-5-sonnet": "anthropic/claude-3-5-sonnet@20241022", + "gemini-pro": "google/gemini-1.5-pro" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "claude-3-5-sonnet": "anthropic/claude-3-5-sonnet@20241022", - "gemini-pro": "google/gemini-1.5-pro" - } + "region": "us-central1" } } ] } ``` -**Use Cases for Deployments**: +**Use Cases for Aliases**: - **Azure**: Map generic model names to specific deployment names in your Azure resource - **Bedrock**: Use short aliases for long inference profile ARNs - **Vertex**: Map to specific model versions or regional endpoints -- **Multi-environment**: Different deployments per key (dev/staging/prod) +- **Multi-environment**: Different aliases per key (dev/staging/prod) **Key Insight**: ``` User Request: {"model": "gpt-4o"} ↓ -Validation: Check if "gpt-4o" in allowed models (derived from deployments) +Validation: Check if "gpt-4o" in allowed models (derived from aliases) ↓ -Mapping: deployments["gpt-4o"] β†’ "my-prod-gpt4o-deployment" +Mapping: aliases["gpt-4o"] β†’ "my-prod-gpt4o-deployment" ↓ API Call: Uses "my-prod-gpt4o-deployment" as deployment ID ↓ -Pricing: Falls back to "gpt-4o" if deployment not in pricing data +Pricing: Falls back to "gpt-4o" if resolved ID not in pricing data ``` -This allows user-friendly model names in requests while supporting provider-specific deployment patterns at the key level. +This allows user-friendly model names in requests while supporting provider-specific identifier patterns at the key level. diff --git a/docs/providers/request-options.mdx b/docs/providers/request-options.mdx index 5fba8b9a0c..dd845f4cd2 100644 --- a/docs/providers/request-options.mdx +++ b/docs/providers/request-options.mdx @@ -30,6 +30,7 @@ Bifrost provides request options that control behavior, enable features, and pas | `semanticcache.CacheNoStoreKey` | `x-bf-cache-no-store` | `bool` | Prevent caching | | `mcp-include-clients` | `x-bf-mcp-include-clients` | `[]string` | Filter MCP clients (comma-separated). | | `mcp-include-tools` | `x-bf-mcp-include-tools` | `[]string` | Filter MCP tools (`clientName-toolName` format, comma-separated) | +| `BifrostContextKeyMCPExtraHeaders` | *(any header in a client's `allowed_extra_headers`)* | `map[string][]string` | Headers forwarded to MCP servers at tool execution time, filtered per-client against `allowed_extra_headers` | | `maxim.TraceIDKey` | `x-bf-maxim-trace-id` | `string` | Maxim trace ID | | `maxim.GenerationIDKey` | `x-bf-maxim-generation-id` | `string` | Maxim generation ID | | `maxim.TagsKey` | `x-bf-maxim-*` | `map[string]string` | Maxim tags (custom tag names) | diff --git a/docs/providers/routing-rules.mdx b/docs/providers/routing-rules.mdx index f57063015d..77fcdd98ce 100644 --- a/docs/providers/routing-rules.mdx +++ b/docs/providers/routing-rules.mdx @@ -8,6 +8,10 @@ icon: "chart-diagram" Routing Rules provide dynamic, expression-based control over request routing. They execute **before governance provider selection** and can override it, allowing you to make sophisticated routing decisions based on request context, headers, parameters, capacity metrics, and organizational hierarchy. + + Routing Rules Tree + + Unlike governance routing (which uses static provider weights), routing rules use **CEL expressions** (Common Expression Language) to evaluate conditions at runtime and make routing decisions dynamically. --- @@ -35,9 +39,11 @@ Global Scope (Lowest Priority, applies to all) **How it works:** 1. When a request arrives with a Virtual Key, Bifrost builds a scope chain 2. Rules are evaluated in scope order (highest to lowest) -3. The **first matching rule** wins - no further rules are evaluated +3. The **first matching rule** wins β€” no further rules are evaluated in that iteration 4. Within each scope, rules are sorted by **priority** (ascending: 0 evaluates before 10) -5. If no rule matches, the incoming provider/model is used unchanged +5. If the matched rule has `chain_rule: true`, the resolved provider/model becomes the new context and the full scope chain is re-evaluated from the top +6. If no rule matches (or the matched rule is terminal), the current decision is applied +7. If no rule ever matches, the incoming provider/model is used unchanged **Example:** ``` @@ -233,6 +239,7 @@ Access routing rules from the dashboard: - **Name** (required): Unique rule identifier - **Description** (optional): Internal notes - **Enabled**: Toggle rule on/off +- **Chain Rule**: When enabled, the routing engine re-evaluates all rules after this one matches, using the resolved provider/model as the new context. See [Rule Chaining](#rule-chaining). - **CEL Expression**: Visual or manual expression builder - **Targets** (required): One or more weighted routing targets β€” each has Provider (optional), Model (optional), API Key (optional, requires Provider to be set), and Weight (%). Weights must sum to 1. When multiple targets are defined, one is selected probabilistically at request time. - **Fallbacks** (optional): Array of fallback providers @@ -271,6 +278,7 @@ GET /api/governance/routing-rules "name": "Premium Tier Route", "description": "Route premium users to fast provider", "enabled": true, + "chain_rule": false, "cel_expression": "headers[\"x-tier\"] == \"premium\"", "targets": [ { "provider": "openai", "model": "gpt-4o", "weight": 0.7 }, @@ -404,6 +412,7 @@ Define routing rules in your `config.json` file under the governance configurati - **name** (string, required): Rule name (must be unique within scope) - **description** (string, optional): Internal documentation - **enabled** (boolean): Whether rule is active +- **chain_rule** (boolean, default: `false`): When `true`, re-evaluates the full routing chain after this rule matches, using the resolved provider/model as the new context. See [Rule Chaining](#rule-chaining). - **cel_expression** (string): CEL expression for rule matching - **targets** (array, required): One or more routing targets. Each target has: - `provider` (string, optional): Target provider β€” omit to use the incoming request provider @@ -586,6 +595,98 @@ Route based on region headers: --- +## Rule Chaining + +Rule chaining is available in **Bifrost v1.5.0-prerelease2 and above**. + +Rule chaining allows routing rules to be composed together. When a rule has `chain_rule: true`, the routing engine does not stop after it matches β€” instead, it updates the request context with the resolved provider/model and re-evaluates the full rule set from the top. + +### How Chaining Works + +``` +Request arrives (provider=openai, model=gpt-4) + ↓ +Rule 1 matches (chain_rule=true) β†’ resolves model to gpt-4-turbo + ↓ +Re-evaluate all rules with (provider=openai, model=gpt-4-turbo) + ↓ +Rule 2 matches (chain_rule=false) β†’ resolves provider to azure + ↓ +Final decision: azure / gpt-4-turbo +``` + +### Termination Conditions + +The chain stops when any of the following occurs: + +| Condition | Description | +|---|---| +| **No match** | Current iteration finds no matching rule | +| **Terminal rule** | Matched rule has `chain_rule: false` (the default) | +| **Convergence** | Provider and model are unchanged after a chain step β€” continuing would loop forever | + +### Decision Accumulation + +Each chain step overwrites the previous decision β€” the last matched rule wins for all fields: + +| Field | Strategy | +|---|---| +| Provider | Last matched rule's target | +| Model | Last matched rule's target | +| API Key | Last matched rule's target (empty = use pool) | +| Fallbacks | Last matched rule's fallbacks | + +Every chain step is logged in the routing engine audit trail for full observability. + +### Configuration Example + +```json +{ + "governance": { + "routing_rules": [ + { + "id": "normalize-alias", + "name": "Normalize gpt-4 Alias", + "enabled": true, + "chain_rule": true, + "cel_expression": "model == \"gpt-4\"", + "targets": [{ "model": "gpt-4-turbo", "weight": 1 }], + "scope": "global", + "priority": 0 + }, + { + "id": "route-gpt4-turbo", + "name": "Route gpt-4-turbo to Azure", + "enabled": true, + "chain_rule": false, + "cel_expression": "model == \"gpt-4-turbo\"", + "targets": [{ "provider": "azure", "model": "gpt-4-turbo", "weight": 1 }], + "scope": "global", + "priority": 1 + } + ] + } +} +``` + +**Result:** A request with `model=gpt-4` is normalized to `gpt-4-turbo` by Rule 1 (chain continues), then routed to Azure by Rule 2 (chain stops). + +### Use Cases + +- **Model alias normalization**: Rewrite short aliases to canonical model names before routing +- **Tiered policy application**: Apply a team-level override first, then a global key-pinning rule +- **Feature flag injection**: A chain rule sets the target to an experimental model; a downstream rule routes that model to the right provider +- **Budget-aware escalation**: A chain rule downgrades the model when budget is high; the next rule routes the downgraded model appropriately + +### Best Practices + +- Keep chains short (2–3 steps) β€” long chains are harder to reason about +- Ensure the last rule in every intended chain path is terminal (`chain_rule: false`) to prevent unintended continuation +- Use convergence detection as a safety net, not a primary termination strategy β€” if you rely on it, your rules likely have a logic gap +- Name chain rules clearly to reflect their role: "Normalize X", "Enrich context", etc. + +--- + ## Integration with Governance & Load Balancing ### Interaction with Governance Routing diff --git a/docs/providers/supported-providers/azure.mdx b/docs/providers/supported-providers/azure.mdx index 1020b0b855..46f41a8840 100644 --- a/docs/providers/supported-providers/azure.mdx +++ b/docs/providers/supported-providers/azure.mdx @@ -116,12 +116,12 @@ detects the auth environment. ```json { + "aliases": { + "gpt-4": "my-gpt4-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -132,18 +132,18 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent ```json { + "aliases": { + "gpt-4": "my-gpt4-deployment", + "gpt-4-turbo": "my-gpt4-turbo-deployment", + "claude-3": "my-claude-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", "client_id": "your-client-id", "client_secret": "your-client-secret", "tenant_id": "your-tenant-id", "scopes": ["https://cognitiveservices.azure.com/.default"], - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment", - "gpt-4-turbo": "my-gpt4-turbo-deployment", - "claude-3": "my-claude-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -156,14 +156,15 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent ```json { + "value": "your-azure-api-key", + "aliases": { + "gpt-4": "my-gpt4-deployment", + "gpt-4-turbo": "my-gpt4-turbo-deployment", + "claude-3": "my-claude-deployment" + }, "azure_key_config": { "endpoint": "https://your-org.openai.azure.com", - "api_version": "2024-10-21", - "deployments": { - "gpt-4": "my-gpt4-deployment", - "gpt-4-turbo": "my-gpt4-turbo-deployment", - "claude-3": "my-claude-deployment" - } + "api_version": "2024-10-21" } } ``` @@ -175,7 +176,7 @@ If you set `client_id`, `client_secret`, and `tenant_id`, Azure Entra ID authent - `tenant_id` - Azure Entra ID tenant ID (optional, for Service Principal auth) - `scopes` - OAuth scopes for token requests (default: `["https://cognitiveservices.azure.com/.default"]`) - `api_version` - API version to use (default: `2024-10-21`) -- `deployments` - Map of model names to deployment IDs (optional, can be provided per-request) +- `aliases` - Map of model names to Azure deployment IDs (optional, set at key level) - `allowed_models` - List of allowed models to use from this key (optional) ### Deployment Selection @@ -189,7 +190,7 @@ Deployments can be specified at three levels (in order of precedence): 2. **Key configuration** ```json - {"deployments": {"gpt-4": "my-gpt4-deployment"}} + {"aliases": {"gpt-4": "my-gpt4-deployment"}} ``` 3. **Model name** (lowest priority, if no deployment specified) diff --git a/docs/providers/supported-providers/bedrock.mdx b/docs/providers/supported-providers/bedrock.mdx index 1ca8b8efb2..52ff79ca3f 100644 --- a/docs/providers/supported-providers/bedrock.mdx +++ b/docs/providers/supported-providers/bedrock.mdx @@ -1206,7 +1206,7 @@ S3-backed file operations. Files are stored in S3 buckets integrated with Bedroc - Deployment mapping from configuration - Model allowlist support (`allowed_models` config) -**Multi-key support**: Results aggregated from all keys, filtered by `allowedModels` if configured +**Multi-key support**: Results aggregated from all keys, filtered by the key-level `models` allowlist if configured --- @@ -1280,43 +1280,43 @@ When using AWS Bedrock inference profiles or application inference profiles, you | Field | Purpose | |-------|---------| | **`arn`** | The ARN prefix (everything before the final `/resource-id`). Required for URL formation when using inference profiles. | -| **`deployments`** | Map logical model names to the **model ID or inference profile resource ID only** β€” not the full ARN. | +| **`aliases`** | Map logical model names to the **model ID or inference profile resource ID only** β€” not the full ARN. Set at the key level, not inside `bedrock_key_config`. | -**Do not** put the full ARN in the deployments mapping. The resource ID (e.g., `abc12xyz`) goes in `deployments`; the ARN prefix goes in the dedicated `arn` field. Putting the full ARN in `deployments` causes malformed URLs and `UnknownOperationException`. +**Do not** put the full ARN in the aliases mapping. The resource ID (e.g., `abc12xyz`) goes in `aliases`; the ARN prefix goes in the dedicated `arn` field inside `bedrock_key_config`. Putting the full ARN in `aliases` causes malformed URLs and `UnknownOperationException`. -**Application inference profiles** β€” use the resource ID (short alphanumeric suffix) in deployments: +**Application inference profiles** β€” use the resource ID (short alphanumeric suffix) in aliases: ```json { + "aliases": { + "claude-opus-4-6": "ghi56rst", + "claude-sonnet-4-5": "jkl78mno" + }, "bedrock_key_config": { "access_key": "your-aws-access-key", "secret_key": "your-aws-secret-key", "session_token": "optional-session-token", "region": "eu-west-1", - "arn": "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile", - "deployments": { - "claude-opus-4-6": "ghi56rst", - "claude-sonnet-4-5": "jkl78mno" - } + "arn": "arn:aws:bedrock:eu-west-1:123456789012:application-inference-profile" } } ``` -**Cross-region inference profiles** β€” use the model identifier (e.g., `us.anthropic.claude-3-5-sonnet-v1:0`) in deployments: +**Cross-region inference profiles** β€” use the model identifier (e.g., `us.anthropic.claude-3-5-sonnet-v1:0`) in aliases: ```json { + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-v1:0" + }, "bedrock_key_config": { "access_key": "your-aws-access-key", "secret_key": "your-aws-secret-key", "session_token": "optional-session-token", "region": "us-east-1", - "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile", - "deployments": { - "claude-sonnet": "us.anthropic.claude-3-5-sonnet-v1:0" - } + "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } ``` diff --git a/docs/providers/supported-providers/replicate.mdx b/docs/providers/supported-providers/replicate.mdx index bec376b731..4ae19e511e 100644 --- a/docs/providers/supported-providers/replicate.mdx +++ b/docs/providers/supported-providers/replicate.mdx @@ -77,10 +77,8 @@ Configure deployed models in the Replicate key configuration. Deployments map cu { "provider": "replicate", "value": "your-api-key", - "replicate_key_config": { - "deployments": { - "my-model": "owner/my-deployment-name" - } + "aliases": { + "my-model": "owner/my-deployment-name" } } ``` diff --git a/docs/providers/supported-providers/vertex.mdx b/docs/providers/supported-providers/vertex.mdx index 538664821a..4ea108d733 100644 --- a/docs/providers/supported-providers/vertex.mdx +++ b/docs/providers/supported-providers/vertex.mdx @@ -522,27 +522,29 @@ To provide a complete model listing experience, Bifrost performs **multi-pass mo - Custom models are identified by having deployment values that contain only digits - Example: `"deployment": "1234567890"` -2. **Second Pass - Non-Custom Models from Deployments** - - Adds standard foundation models from your `deployments` configuration +2. **Second Pass - Non-Custom Models from Aliases** + - Adds standard foundation models from your `aliases` configuration - Non-custom models have alphanumeric deployment values (e.g., `gemini-pro`, `claude-3-5-sonnet`) - - Filters by `allowedModels` if specified + - Filters by the key-level `models` allowlist, if specified - Example: `"deployment": "gemini-2.0-flash"` -3. **Third Pass - Allowed Models Not in Deployments** - - Adds models specified in `allowedModels` that weren't in the `deployments` map +3. **Third Pass - Allowed Models Not in Aliases** + - Adds models specified in `models` that weren't in the `aliases` map - Ensures all explicitly allowed models appear in the list - Uses the model name itself as the deployment value - Skips digit-only model IDs (reserved for custom models) ### Model Filtering Logic -- **If `allowedModels` is empty**: All models from all three passes are included -- **If `allowedModels` is non-empty**: Only models/deployments with keys in `allowedModels` are included +- **If `models` is empty and no aliases are configured**: No models are returned +- **If `models` is empty but aliases are configured**: Only aliased models are returned +- **If `models` is `["*"]`**: All models from all three passes are included (unrestricted) +- **If `models` is non-empty**: Only models/aliases whose request names appear in `models` are included - **Duplicate Prevention**: Each model ID is tracked to prevent duplicates across passes ### Model Name Formatting -Non-custom models from deployments and allowed models are automatically formatted for display: +Non-custom models from aliases and allowed models are automatically formatted for display: - `gemini-pro` β†’ "Gemini Pro" - `claude-3-5-sonnet` β†’ "Claude 3 5 Sonnet" @@ -557,13 +559,13 @@ Formatting uses title case and converts hyphens/underscores to spaces. ```json { + "aliases": { + "my-gemini-ft": "1234567890", + "my-claude-ft": "9876543210" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "my-gemini-ft": "1234567890", - "my-claude-ft": "9876543210" - } + "region": "us-central1" } } ``` @@ -575,33 +577,33 @@ This returns only your custom fine-tuned models from the API. ```json { + "aliases": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "gemini-2.0-flash": "gemini-2.0-flash", - "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022" - } + "region": "us-central1" } } ``` -This returns both custom models AND foundation models from deployments. +This returns both custom models AND foundation models from aliases. ```json { + "models": ["gemini-2.0-flash", "claude-3-5-sonnet"], + "aliases": { + "gemini-2.0-flash": "gemini-2.0-flash", + "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022", + "gemini-1.5-pro": "gemini-1.5-pro" + }, "vertex_key_config": { "project_id": "my-project", - "region": "us-central1", - "deployments": { - "gemini-2.0-flash": "gemini-2.0-flash", - "claude-3-5-sonnet": "claude-3-5-sonnet-v2@20241022", - "gemini-1.5-pro": "gemini-1.5-pro" - }, - "allowedModels": ["gemini-2.0-flash", "claude-3-5-sonnet"] + "region": "us-central1" } } ``` @@ -664,7 +666,7 @@ Model listing is paginated automatically. If more than 100 models exist, `next_p **Severity**: High **Behavior**: Vertex AI's List Models API only returns custom fine-tuned models, NOT foundation models -**Impact**: Bifrost performs three-pass discovery to include foundation models from deployments and allowedModels configuration +**Impact**: Bifrost performs three-pass discovery to include foundation models from aliases and the key-level `models` allowlist **Why**: This is a Vertex AI API limitation - foundation models must be explicitly configured **Code**: `models.go:76-217` diff --git a/docs/quickstart/gateway/provider-configuration.mdx b/docs/quickstart/gateway/provider-configuration.mdx index 2e4e2b71e3..5986ed87f6 100644 --- a/docs/quickstart/gateway/provider-configuration.mdx +++ b/docs/quickstart/gateway/provider-configuration.mdx @@ -1054,7 +1054,7 @@ Azure supports three authentication methods: **Managed Identity** (DefaultAzureC #### Managed Identity / DefaultAzureCredential -Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredential`, which auto-detects managed identity on Azure VMs, App Service, AKS, and similar environments. Provide only `endpoint`, `deployments`, and optionally `api_version`. +Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredential`, which auto-detects managed identity on Azure VMs, App Service, AKS, and similar environments. Provide only `endpoint` and optionally `api_version`. #### Azure Entra ID (Service Principal) @@ -1070,7 +1070,7 @@ Leave API key and Entra ID credentials empty. Bifrost uses `DefaultAzureCredenti 4. Set **Client Secret**: Your Azure Entra ID client secret 5. Set **Tenant ID**: Your Azure Entra ID tenant ID 6. Set **Endpoint**: Your Azure endpoint URL -7. Configure **Deployments**: Map model names to deployment names +7. Configure **Aliases**: Map model names to deployment names 8. Set **API Version**: e.g., `2024-08-01-preview` 9. Save configuration @@ -1089,16 +1089,16 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", "client_id": "env.AZURE_CLIENT_ID", "client_secret": "env.AZURE_CLIENT_SECRET", "tenant_id": "env.AZURE_TENANT_ID", "scopes": ["https://cognitiveservices.azure.com/.default"], - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1120,16 +1120,16 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", "client_id": "env.AZURE_CLIENT_ID", "client_secret": "env.AZURE_CLIENT_SECRET", "tenant_id": "env.AZURE_TENANT_ID", "scopes": ["https://cognitiveservices.azure.com/.default"], - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1156,7 +1156,7 @@ For simpler use cases, provide the authentication credential directly in the `va 1. Navigate to **"Model Providers"** β†’ **"Configurations"** β†’ **"Azure"** 2. Set **API Key**: Your Azure API key 3. Set **Endpoint**: Your Azure endpoint URL -4. Configure **Deployments**: Map model names to deployment names +4. Configure **Aliases**: Map model names to deployment names 5. Set **API Version**: e.g., `2024-08-01-preview` 6. Save configuration @@ -1175,12 +1175,12 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "env.AZURE_API_KEY", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1202,12 +1202,12 @@ curl --location 'http://localhost:8080/api/providers' \ "value": "env.AZURE_API_KEY", "models": ["gpt-4o", "gpt-4o-mini"], "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment" + }, "azure_key_config": { "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment" - }, "api_version": "2024-08-01-preview" } } @@ -1240,8 +1240,8 @@ AWS Bedrock supports both explicit credentials and IAM role authentication: 3. Set **Access Key**: AWS Access Key ID (or leave empty to use IAM in environment) 4. Set **Secret Key**: AWS Secret Access Key (or leave empty to use IAM in environment) 5. Set **Region**: e.g., `us-east-1` -6. Configure **Deployments**: Map model names to inference profiles -7. Set **ARN**: Required for deployments mapping +6. Configure **Aliases**: Map model names to inference profiles +7. Set **ARN**: Required only when Bifrost must construct a full inference-profile ARN for an alias 8. Save configuration @@ -1256,16 +1256,16 @@ curl --location 'http://localhost:8080/api/providers' \ "keys": [ { "name": "bedrock-key-1", - "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"], "weight": 1.0, + "aliases": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", "session_token": "env.AWS_SESSION_TOKEN", "region": "us-east-1", - "deployments": { - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" - }, "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } @@ -1284,16 +1284,16 @@ curl --location 'http://localhost:8080/api/providers' \ "keys": [ { "name": "bedrock-key-1", - "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"], + "models": ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"], "weight": 1.0, + "aliases": { + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" + }, "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID", "secret_key": "env.AWS_SECRET_ACCESS_KEY", "session_token": "env.AWS_SESSION_TOKEN", "region": "us-east-1", - "deployments": { - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0" - }, "arn": "arn:aws:bedrock:us-east-1:123456789012:inference-profile" } } @@ -1310,9 +1310,9 @@ curl --location 'http://localhost:8080/api/providers' \ **Notes:** - If using API Key authentication, set `value` field to the API key, else leave it empty for IAM role authentication. - In IAM role authentication, if both `access_key` and `secret_key` are empty, Bifrost uses IAM role authentication from the environment. -- `arn` is required for URL formation - `deployments` mapping is ignored without it. -- When using `arn` + `deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. -- **ARN vs deployments**: Put the ARN prefix in `arn` and the model/inference profile resource ID only in `deployments` β€” never the full ARN in deployments. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. +- `arn` is required when you want Bifrost to build a full inference-profile ARN from an alias target. +- Aliases are still resolved before provider dispatch; without `arn`, the resolved alias value is sent as the Bedrock model/profile identifier directly. +- **ARN vs aliases**: Put the ARN prefix in `arn` and the model/inference profile resource ID only in the key-level `aliases` map β€” never the full ARN in alias values. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. ### Google Vertex @@ -1343,15 +1343,16 @@ curl --location 'http://localhost:8080/api/providers' \ { "name": "vertex-key-1", "value": "env.VERTEX_API_KEY", - "models": ["gemini-pro", "gemini-pro-vision"], + "models": ["gemini-pro", "gemini-pro-vision", "123456789", "fine-tuned-gemini-2.5-pro"], "weight": 1.0, + "aliases": { + "fine-tuned-gemini-2.5-pro": "123456789" + }, "vertex_key_config": { "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS", - "deployments": { - "fine-tuned-gemini-2.5-pro": "123456789" - } + "auth_credentials": "env.VERTEX_CREDENTIALS" } } ] @@ -1370,15 +1371,16 @@ curl --location 'http://localhost:8080/api/providers' \ { "name": "vertex-key-1", "value": "env.VERTEX_API_KEY", - "models": ["gemini-pro", "gemini-pro-vision"], + "models": ["gemini-pro", "gemini-pro-vision", "123456789", "fine-tuned-gemini-2.5-pro"], "weight": 1.0, + "aliases": { + "fine-tuned-gemini-2.5-pro": "123456789" + }, "vertex_key_config": { "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS", - "deployments": { - "fine-tuned-gemini-2.5-pro": "123456789" - } + "auth_credentials": "env.VERTEX_CREDENTIALS" } } ] @@ -1395,7 +1397,7 @@ curl --location 'http://localhost:8080/api/providers' \ - You can leave both API Key and Auth Credentials empty to use service account authentication from the environment. - You must set Project Number in Key config if using fine-tuned models. - API Key Authentication is only supported for Gemini and fine-tuned models. -- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the deployments in the key config. +- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the aliases on the key. Vertex AI support for fine-tuned models is currently in beta. Requests to non-Gemini fine-tuned models may fail, so please test and report any issues. diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx index 448d5e8b39..87c2901ece 100644 --- a/docs/quickstart/go-sdk/provider-configuration.mdx +++ b/docs/quickstart/go-sdk/provider-configuration.mdx @@ -417,16 +417,17 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo Value: "", // Leave empty for Service Principal auth Models: []string{"gpt-4o", "gpt-4o-mini"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: os.Getenv("AZURE_ENDPOINT"), ClientID: bifrost.Ptr(os.Getenv("AZURE_CLIENT_ID")), ClientSecret: bifrost.Ptr(os.Getenv("AZURE_CLIENT_SECRET")), TenantID: bifrost.Ptr(os.Getenv("AZURE_TENANT_ID")), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment", - }, - APIVersion: bifrost.Ptr("2024-08-01-preview"), + Scopes: []string{"https://cognitiveservices.azure.com/.default"}, + APIVersion: bifrost.Ptr("2024-08-01-preview"), }, }, }, nil @@ -448,12 +449,12 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo Value: os.Getenv("AZURE_OPENAI_KEY"), Models: []string{"gpt-4o", "gpt-4o-mini"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "gpt-4o-deployment", + "gpt-4o-mini": "gpt-4o-mini-deployment", + }, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", - "gpt-4o-mini": "gpt-4o-mini-deployment", - }, + Endpoint: os.Getenv("AZURE_ENDPOINT"), APIVersion: bifrost.Ptr("2024-08-01-preview"), }, }, @@ -479,19 +480,19 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo case schemas.Bedrock: return []schemas.Key{ { - Models: []string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"}, + Models: []string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "claude-3-sonnet"}, Weight: 1.0, Value: os.Getenv("AWS_API_KEY"), // Leave empty for IAM role authentication + // Model profiles (inference profiles): map short names to profile resource IDs + Aliases: schemas.KeyAliases{ + "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0", + }, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), // Leave empty for API Key authentication or system's IAM pickup SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), // Leave empty for API Key authentication or system's IAM pickup SessionToken: bifrost.Ptr(os.Getenv("AWS_SESSION_TOKEN")), // Optional Region: bifrost.Ptr("us-east-1"), - // For model profiles (inference profiles) - Deployments: map[string]string{ - "claude-3-sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0", - }, - // For direct model access without profiles + // ARN prefix for profile URLs; put resource IDs only in Aliases, not full ARNs ARN: bifrost.Ptr("arn:aws:bedrock:us-east-1:123456789012:inference-profile"), }, }, @@ -504,9 +505,9 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo **Notes:** - If using API Key authentication, set `Value` field to the API key, else leave it empty for IAM role authentication. - In IAM role authentication, if both `AccessKey` and `SecretKey` are empty, Bifrost uses IAM from the environment. -- `ARN` is required for URL formation - `Deployments` mapping is ignored without it. -- When using `ARN` + `Deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly. -- **ARN vs Deployments**: Put the ARN prefix in `ARN` and the model/inference profile resource ID only in `Deployments` β€” never the full ARN in Deployments. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. +- `ARN` is required when you want Bifrost to build a full inference-profile ARN from an alias target. +- Aliases are still resolved before provider dispatch; without `ARN`, the resolved alias value is sent as the Bedrock model/profile identifier directly. +- **ARN vs Aliases**: Put the ARN prefix in `ARN` and the model/inference profile resource ID only in `Aliases` β€” never the full ARN in alias values. See [How to Use ARNs and Application Inference Profiles](/providers/supported-providers/bedrock#how-to-use-arns-and-application-inference-profiles) for details. @@ -521,16 +522,16 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo return []schemas.Key{ { Value: os.Getenv("VERTEX_API_KEY"), // only when using gemini or fine-tuned models - Models: []string{"gemini-pro", "gemini-pro-vision"}, + Models: []string{"gemini-pro", "gemini-pro-vision", "fine-tuned-gemini-2.5-pro"}, Weight: 1.0, + Aliases: schemas.KeyAliases{ + "fine-tuned-gemini-2.5-pro": "123456789", + }, VertexKeyConfig: &schemas.VertexKeyConfig{ - ProjectID: os.Getenv("VERTEX_PROJECT_ID"), // GCP project ID - ProjectNumber: os.Getenv("VERTEX_PROJECT_NUMBER"), // GCP project number (only when using fine-tuned models) - Region: "us-central1", // GCP region - AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), // Service account JSON - Deployments: map[string]string{ - "fine-tuned-gemini-2.5-pro": "123456789" - }, + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), // GCP project ID + ProjectNumber: os.Getenv("VERTEX_PROJECT_NUMBER"), // GCP project number (only when using fine-tuned models) + Region: "us-central1", // GCP region + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), // Service account JSON }, }, }, nil @@ -543,7 +544,7 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo - You can leave both API Key and Auth Credentials empty to use service account authentication from the environment. - You must set Project Number if using fine-tuned models. - API Key Authentication is only supported for Gemini and fine-tuned models. -- You can use custom fine-tuned models by passing `vertex/` or `vertex/` if you have set the deployments in the key config. +- You can use custom fine-tuned models by passing `vertex/` if you have set the aliases on the key. Vertex AI support for fine-tuned models is currently in beta. Requests to non-Gemini fine-tuned models may fail, so please test and report any issues. diff --git a/examples/configs/partial/config.json b/examples/configs/partial/config.json index f2fb269747..e748f459ce 100644 --- a/examples/configs/partial/config.json +++ b/examples/configs/partial/config.json @@ -20,7 +20,8 @@ { "name": "openai-key-1", "value": "sk-123", - "weight": 1 + "weight": 1, + "models": ["*"] } ] }, @@ -29,7 +30,8 @@ { "name": "anthropic-key-1", "value": "sk-456", - "weight": 1 + "weight": 1, + "models": ["*"] } ] }, @@ -38,12 +40,14 @@ { "name": "bedrock-key-1", "value": "ak-123", - "weight": 1 + "weight": 1, + "models": ["*"] }, { "name": "bedrock-key-2", "value": "ak-456", - "weight": 1 + "weight": 1, + "models": ["*"] } ] } diff --git a/examples/configs/withconfigstore/config.json b/examples/configs/withconfigstore/config.json index 2f0ea09a6e..c6559a0024 100644 --- a/examples/configs/withconfigstore/config.json +++ b/examples/configs/withconfigstore/config.json @@ -22,21 +22,7 @@ "provider_configs": [ { "provider": "azure", - "keys":[{ - "key_id":"8c52039e-38c6-48b2-8016-0bd884b7befb", - "value":"abc", - "name":"azure-key-1", - "weight": 0.5, - "azure_key_config":{ - "endpoint":"https://api.azure.com", - "api_version":"2024-09-01", - "deployments":{ - "gpt-4.1-2025-04-14":"gpt-4.1-2025-04-14", - "gpt-4.1-mini-2025-04-14":"gpt-4.1-mini-2025-04-14", - "gpt-4.1-nano-2025-04-14":"gpt-4.1-nano-2025-04-14" - } - } - }], + "key_ids": ["*"], "allowed_models": [ "gpt-4.1-2025-04-14", "gpt-4.1-mini-2025-04-14", diff --git a/examples/configs/withlogstore/config.json b/examples/configs/withlogstore/config.json index 613c027f35..cdcaf17d7b 100644 --- a/examples/configs/withlogstore/config.json +++ b/examples/configs/withlogstore/config.json @@ -16,7 +16,8 @@ { "name": "openai-key-1", "value": "sk-proj-abc", - "weight": 1 + "weight": 1, + "models": ["*"] } ] } diff --git a/examples/configs/withpostgresmcpclientsinconfig/config.json b/examples/configs/withpostgresmcpclientsinconfig/config.json index 600267db03..8e03969988 100644 --- a/examples/configs/withpostgresmcpclientsinconfig/config.json +++ b/examples/configs/withpostgresmcpclientsinconfig/config.json @@ -88,7 +88,9 @@ "provider_configs": [ { "provider": "openai", - "weight": 1.0 + "weight": 1.0, + "allowed_models": ["*"], + "key_ids": ["*"] } ] }, @@ -109,7 +111,9 @@ "provider_configs": [ { "provider": "openai", - "weight": 1.0 + "weight": 1.0, + "allowed_models": ["*"], + "key_ids": ["*"] } ] } @@ -130,7 +134,8 @@ { "name": "openai-primary", "value": "env.OPENAI_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ] } diff --git a/examples/configs/withpricingoverridesnostore/config.json b/examples/configs/withpricingoverridesnostore/config.json new file mode 100644 index 0000000000..cfb29ebd35 --- /dev/null +++ b/examples/configs/withpricingoverridesnostore/config.json @@ -0,0 +1,74 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "config_store": { + "enabled": false + }, + "logs_store": { + "enabled": false + }, + "governance": { + "pricing_overrides": [ + { + "id": "override-global-gpt4o", + "name": "Global GPT-4o Pricing", + "scope_kind": "global", + "match_type": "exact", + "pattern": "gpt-4o", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.0000025,\"output_cost_per_token\":0.00001}" + }, + { + "id": "override-global-claude-wildcard", + "name": "Global Claude Models Pricing", + "scope_kind": "global", + "match_type": "wildcard", + "pattern": "claude-*", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.000003,\"output_cost_per_token\":0.000015}" + }, + { + "id": "override-provider-openai-gpt4o-mini", + "name": "OpenAI GPT-4o Mini Pricing", + "scope_kind": "provider", + "provider_id": "openai", + "match_type": "exact", + "pattern": "gpt-4o-mini", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.00000015,\"output_cost_per_token\":0.0000006}" + } + ] + }, + "plugins": [ + { + "name": "governance", + "enabled": true, + "config": { + "is_vk_mandatory": false + } + } + ], + "providers": { + "openai": { + "keys": [ + { + "id": "key-openai-1", + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "weight": 1, + "models": ["*"] + } + ] + }, + "anthropic": { + "keys": [ + { + "id": "key-anthropic-1", + "name": "anthropic-key-1", + "value": "env.ANTHROPIC_API_KEY", + "weight": 1, + "models": ["*"] + } + ] + } + } +} diff --git a/examples/configs/withpricingoverridessqlite/config.json b/examples/configs/withpricingoverridessqlite/config.json new file mode 100644 index 0000000000..b99094bcea --- /dev/null +++ b/examples/configs/withpricingoverridessqlite/config.json @@ -0,0 +1,82 @@ +{ + "$schema": "https://www.getbifrost.ai/schema", + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "config.db" + } + }, + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "logs.db" + } + }, + "governance": { + "pricing_overrides": [ + { + "id": "override-global-gpt4o", + "name": "Global GPT-4o Pricing", + "scope_kind": "global", + "match_type": "exact", + "pattern": "gpt-4o", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.0000025,\"output_cost_per_token\":0.00001}" + }, + { + "id": "override-global-claude-wildcard", + "name": "Global Claude Models Pricing", + "scope_kind": "global", + "match_type": "wildcard", + "pattern": "claude-*", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.000003,\"output_cost_per_token\":0.000015}" + }, + { + "id": "override-provider-openai-gpt4o-mini", + "name": "OpenAI GPT-4o Mini Pricing", + "scope_kind": "provider", + "provider_id": "openai", + "match_type": "exact", + "pattern": "gpt-4o-mini", + "request_types": ["chat_completion"], + "pricing_patch": "{\"input_cost_per_token\":0.00000015,\"output_cost_per_token\":0.0000006}" + } + ] + }, + "plugins": [ + { + "name": "governance", + "enabled": true, + "config": { + "is_vk_mandatory": false + } + } + ], + "providers": { + "openai": { + "keys": [ + { + "id": "key-openai-1", + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "weight": 1, + "models": ["*"] + } + ] + }, + "anthropic": { + "keys": [ + { + "id": "key-anthropic-1", + "name": "anthropic-key-1", + "value": "env.ANTHROPIC_API_KEY", + "weight": 1, + "models": ["*"] + } + ] + } + } +} diff --git a/examples/configs/withprompushgateway/config.json b/examples/configs/withprompushgateway/config.json index cee87e71e4..f697041388 100644 --- a/examples/configs/withprompushgateway/config.json +++ b/examples/configs/withprompushgateway/config.json @@ -7,7 +7,8 @@ "name": "OpenAI API Key", "value": "env.OPENAI_API_KEY", "weight": 1, - "use_for_batch_api": true + "use_for_batch_api": true, + "models": ["*"] } ], "network_config": { @@ -20,7 +21,8 @@ "name": "Anthropic API Key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, - "use_for_batch_api": true + "use_for_batch_api": true, + "models": ["*"] } ], "network_config": { @@ -32,7 +34,8 @@ { "value": "env.GEMINI_API_KEY", "weight": 1, - "use_for_batch_api": true + "use_for_batch_api": true, + "models": ["*"] } ], "network_config": { @@ -48,7 +51,8 @@ "region": "env.GOOGLE_LOCATION", "auth_credentials": "env.VERTEX_CREDENTIALS" }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -60,7 +64,8 @@ { "name": "Mistral API Key", "value": "env.MISTRAL_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -72,7 +77,8 @@ { "name": "Cohere API Key", "value": "env.COHERE_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -84,7 +90,8 @@ { "name": "Groq API Key", "value": "env.GROQ_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -96,7 +103,8 @@ { "name": "Perplexity API Key", "value": "env.PERPLEXITY_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -108,7 +116,8 @@ { "name": "Cerebras API Key", "value": "env.CEREBRAS_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -120,7 +129,8 @@ { "name": "OpenRouter API Key", "value": "env.OPENROUTER_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -136,7 +146,8 @@ "endpoint": "env.AZURE_ENDPOINT", "api_version": "env.AZURE_API_VERSION" }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -154,7 +165,8 @@ "arn": "env.AWS_ARN" }, "weight": 1, - "use_for_batch_api": true + "use_for_batch_api": true, + "models": ["*"] } ], "network_config": { diff --git a/examples/configs/withvirtualkeys/config.json b/examples/configs/withvirtualkeys/config.json index 7dcc8b67d2..a968bad65c 100644 --- a/examples/configs/withvirtualkeys/config.json +++ b/examples/configs/withvirtualkeys/config.json @@ -40,24 +40,27 @@ "name": "prod-assistant-us-key-01-configurations", "provider_configs": [ { - "allowed_keys": [ - "azure-us-key-1-prod" + "key_ids": [ + "key-azure-us-1-prod" ], + "allowed_models": ["*"], "provider": "azure", "weight": 0.5 }, { - "allowed_keys": [ - "vertex-us-east1-prod", - "vertex-global-prod" + "key_ids": [ + "key-vertex-us-east1-prod", + "key-vertex-global-prod" ], + "allowed_models": ["*"], "provider": "vertex", "weight": 0.5 }, { - "allowed_keys": [ - "openai-us-key-1-prod" + "key_ids": [ + "key-openai-us-1-prod" ], + "allowed_models": ["*"], "provider": "openai", "weight": 0.5 } @@ -70,24 +73,27 @@ "name": "prod-assistant-eu-key-01-configurations", "provider_configs": [ { - "allowed_keys": [ - "azure-eu-key-1-prod" + "key_ids": [ + "key-azure-eu-1-prod" ], + "allowed_models": ["*"], "provider": "azure", "weight": 0.5 }, { - "allowed_keys": [ - "vertex-europe-west1-prod", - "vertex-global-prod" + "key_ids": [ + "key-vertex-eu-west1-prod", + "key-vertex-global-prod" ], + "allowed_models": ["*"], "provider": "vertex", "weight": 0.5 }, { - "allowed_keys": [ - "bedrock-eu-central-1-prod" + "key_ids": [ + "key-bedrock-eu-central-1-prod" ], + "allowed_models": ["*"], "provider": "bedrock", "weight": 0.5 } @@ -116,6 +122,7 @@ "azure": { "keys": [ { + "id": "key-azure-us-1-prod", "azure_key_config": { "api_version": "2025-03-01-preview", "deployments": { @@ -143,6 +150,7 @@ "weight": 1 }, { + "id": "key-azure-us-2-prod", "azure_key_config": { "api_version": "2025-03-01-preview", "deployments": { @@ -170,6 +178,7 @@ "weight": 1 }, { + "id": "key-azure-eu-1-prod", "azure_key_config": { "api_version": "2025-03-01-preview", "deployments": { @@ -201,32 +210,35 @@ "bedrock": { "keys": [ { + "id": "key-bedrock-us-east-1-prod", "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID_US_EAST_1", "region": "us-east-1", "secret_key": "env.AWS_SECRET_ACCESS_KEY_US_EAST_1" }, - "models": [], + "models": ["*"], "name": "bedrock-us-east-1-prod", "weight": 1 }, { + "id": "key-bedrock-us-west-2-prod", "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID_US_WEST_2", "region": "us-west-2", "secret_key": "env.AWS_SECRET_ACCESS_KEY_US_WEST_2" }, - "models": [], + "models": ["*"], "name": "bedrock-us-west-2-prod", "weight": 1 }, { + "id": "key-bedrock-eu-central-1-prod", "bedrock_key_config": { "access_key": "env.AWS_ACCESS_KEY_ID_EU_CENTRAL_1", "region": "eu-central-1", "secret_key": "env.AWS_SECRET_ACCESS_KEY_EU_CENTRAL_1" }, - "models": [], + "models": ["*"], "name": "bedrock-eu-central-1-prod", "weight": 1 } @@ -235,6 +247,7 @@ "vertex": { "keys": [ { + "id": "key-vertex-us-east1-prod", "models": [ "google/gemini-2.5-pro", "google/gemini-2.5-flash-lite", @@ -249,6 +262,7 @@ "weight": 1 }, { + "id": "key-vertex-eu-west1-prod", "models": [ "google/gemini-2.5-pro", "google/gemini-2.5-flash-lite", @@ -263,6 +277,7 @@ "weight": 1 }, { + "id": "key-vertex-us-west1-prod", "models": [ "google/gemini-2.5-pro", "google/gemini-2.5-flash-lite", @@ -277,6 +292,7 @@ "weight": 1 }, { + "id": "key-vertex-global-prod", "models": [ "google/gemini-3-pro-preview", "google/gemini-3-flash-preview" @@ -294,9 +310,11 @@ "openai": { "keys": [ { + "id": "key-openai-us-1-prod", "name": "openai-us-key-1-prod", "value": "env.OPENAI_API_KEY_US_EAST_1", - "weight": 1 + "weight": 1, + "models": ["*"] } ] } diff --git a/examples/mcps/auth-demo-server/main.go b/examples/mcps/auth-demo-server/main.go index ca6bdd6581..4f47d8124b 100644 --- a/examples/mcps/auth-demo-server/main.go +++ b/examples/mcps/auth-demo-server/main.go @@ -7,8 +7,11 @@ package main // tools/call). A missing or wrong key is rejected before the MCP server // sees the message at all. // -// 2. TOOL-LEVEL AUTH (X-Role header) -// Enforced inside individual sensitive tool handlers. Public tools ignore it. +// 2. TOOL-EXECUTION AUTH (X-Tool-Token header) +// A separate secret token checked exclusively inside sensitive tool handlers +// at call time. Public tools ignore it; the connection middleware does not +// inspect it at all. This lets you scope a second credential to tool +// execution only β€” distinct from the connection credential. // // HOW BIFROST SENDS HEADERS // @@ -22,7 +25,8 @@ package main // This means all configured headers are present on EVERY request β€” there is no // separate "connection-only" vs "tool-only" header mechanism in Bifrost. To // distinguish the two auth levels you simply use different header names, both -// configured in the same `headers` map. +// configured in the same `headers` map. The server then enforces each header +// at the appropriate layer (middleware vs. handler). // // Bifrost config example: // @@ -32,8 +36,8 @@ package main // "connection_string": "http://localhost:3002/", // "auth_type": "headers", // "headers": { -// "X-API-Key": "super-secret-key", -// "X-Role": "admin" +// "X-API-Key": "super-secret-key", +// "X-Tool-Token": "tool-exec-secret" // }, // "tools_to_execute": ["*"] // } @@ -50,14 +54,16 @@ import ( ) const ( - // connectionAPIKey is checked in HTTP middleware on every request. + // connectionAPIKey is checked in HTTP middleware on every request + // (initialize, tools/list, tools/call). // In production, load this from an environment variable or secrets manager. connectionAPIKey = "super-secret-key" - // requiredRole is checked inside the sensitive tool handler only. - // Both X-API-Key and X-Role are configured together in Bifrost's `headers` - // map and are forwarded on every HTTP request (connection and tool calls). - requiredRole = "admin" + // toolExecToken is checked exclusively inside sensitive tool handlers β€” + // never in the connection middleware. It acts as a second independent + // credential that gates tool execution only. + // In production, load this from an environment variable or secrets manager. + toolExecToken = "tool-exec-secret" ) // contextKey is a private type so we don't collide with other packages' context keys. @@ -69,7 +75,7 @@ func main() { s := server.NewMCPServer("auth-demo-server", "1.0.0") // public_info only requires connection-level auth (X-API-Key). - // Any authenticated client can call it regardless of role. + // Any authenticated client can call it without a tool execution token. publicTool := mcp.NewTool( "public_info", mcp.WithDescription("Returns non-sensitive public information. Requires connection auth (X-API-Key) only."), @@ -77,13 +83,14 @@ func main() { ) s.AddTool(publicTool, publicInfoHandler) - // secret_data requires BOTH connection-level auth (X-API-Key) AND - // a role check (X-Role: admin) inside the handler. + // secret_data requires BOTH connection-level auth (X-API-Key) AND a + // dedicated tool-execution token (X-Tool-Token) checked inside the handler. // In Bifrost both headers live in the same `headers` map and arrive on - // every request, so the handler just reads X-Role from the context. + // every request, so the handler reads X-Tool-Token from context and + // validates it independently of the connection credential. secretTool := mcp.NewTool( "secret_data", - mcp.WithDescription("Returns sensitive data. Requires connection auth (X-API-Key) AND role check (X-Role: admin)."), + mcp.WithDescription("Returns sensitive data. Requires connection auth (X-API-Key) AND tool-execution auth (X-Tool-Token)."), mcp.WithString("resource", mcp.Required(), mcp.Description("Resource name to fetch")), ) s.AddTool(secretTool, secretDataHandler) @@ -100,10 +107,11 @@ func main() { addr := "localhost:3002" log.Printf("auth-demo-server listening on http://%s/", addr) log.Printf("\nAuth layers:") - log.Printf(" Connection-level: X-API-Key: %s (middleware rejects all requests without it)", connectionAPIKey) - log.Printf(" Tool-level: X-Role: %s (only secret_data checks this, read from context)", requiredRole) + log.Printf(" Connection-level: X-API-Key: %s (middleware rejects all requests without it)", connectionAPIKey) + log.Printf(" Tool-execution: X-Tool-Token: %s (only secret_data checks this, validated inside the handler)", toolExecToken) log.Printf("\nNote: Bifrost sends all `headers` on both connection setup AND every tool call.") - log.Printf("Both X-API-Key and X-Role go in the same `headers` map.\n") + log.Printf("Both X-API-Key and X-Tool-Token go in the same `headers` map.") + log.Printf("The server enforces each at the right layer: middleware vs. handler.\n") log.Printf("Bifrost config:") log.Printf(` { @@ -112,12 +120,12 @@ func main() { "connection_string": "http://%s/", "auth_type": "headers", "headers": { - "X-API-Key": "%s", - "X-Role": "%s" + "X-API-Key": "%s", + "X-Tool-Token": "%s" }, "tools_to_execute": ["*"] } -`, addr, connectionAPIKey, requiredRole) +`, addr, connectionAPIKey, toolExecToken) if err := http.ListenAndServe(addr, handler); err != nil { log.Fatalf("Server error: %v", err) @@ -174,21 +182,22 @@ func publicInfoHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT } // secretDataHandler handles "secret_data". Connection-level auth (X-API-Key) -// has already been verified by middleware. Here we additionally check X-Role, -// which Bifrost sends as part of the same `headers` map β€” so it is present on -// every request, including this tool call. +// has already been verified by middleware. Here we additionally check +// X-Tool-Token β€” a separate secret dedicated to authorizing tool execution. +// Bifrost sends it as part of the same `headers` map, so it arrives on every +// request including this tool call; the middleware intentionally ignores it. func secretDataHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // ── Tool-level role check ──────────────────────────────────────────────── + // ── Tool-execution token check ─────────────────────────────────────────── headers, ok := ctx.Value(requestHeadersKey).(http.Header) if !ok { return mcp.NewToolResultError("tool auth error: request headers unavailable in context"), nil } - role := headers.Get("X-Role") - if role == "" { - return mcp.NewToolResultError("tool auth required: missing X-Role header"), nil + token := headers.Get("X-Tool-Token") + if token == "" { + return mcp.NewToolResultError("tool auth required: missing X-Tool-Token header"), nil } - if role != requiredRole { - return mcp.NewToolResultError(fmt.Sprintf("tool auth failed: role %q is not authorized for this tool", role)), nil + if token != toolExecToken { + return mcp.NewToolResultError("tool auth failed: invalid X-Tool-Token"), nil } // ── Auth passed, proceed ───────────────────────────────────────────────── @@ -200,7 +209,7 @@ func secretDataHandler(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT } return mcp.NewToolResultText(fmt.Sprintf( - "Secret data for resource %q: [classified content β€” X-API-Key + X-Role:%s verified]", args.Resource, role, + "Secret data for resource %q: [classified content β€” X-API-Key + X-Tool-Token verified]", args.Resource, )), nil } diff --git a/examples/plugins/hello-world/go.mod b/examples/plugins/hello-world/go.mod index d522d70007..4a4997d9f5 100644 --- a/examples/plugins/hello-world/go.mod +++ b/examples/plugins/hello-world/go.mod @@ -2,7 +2,7 @@ module github.com/maximhq/bifrost/examples/plugins/hello-world go 1.26.1 -require github.com/maximhq/bifrost/core v1.4.17 +require github.com/maximhq/bifrost/core v1.5.1 require ( github.com/andybalholm/brotli v1.2.0 // indirect diff --git a/examples/plugins/hello-world/go.sum b/examples/plugins/hello-world/go.sum index 363f6115f0..5203db06bc 100644 --- a/examples/plugins/hello-world/go.sum +++ b/examples/plugins/hello-world/go.sum @@ -39,8 +39,8 @@ github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8 github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/examples/plugins/hello-world/main.go b/examples/plugins/hello-world/main.go index 4d6de8609c..2f464e2a7d 100644 --- a/examples/plugins/hello-world/main.go +++ b/examples/plugins/hello-world/main.go @@ -6,6 +6,12 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +const ( + transportPreHookKey schemas.BifrostContextKey = "hello-world-plugin-transport-pre-hook" + transportPostHookKey schemas.BifrostContextKey = "hello-world-plugin-transport-post-hook" + preHookKey schemas.BifrostContextKey = "hello-world-plugin-pre-hook" +) + func Init(config any) error { fmt.Println("Init called") return nil @@ -23,8 +29,9 @@ func HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) // Modify request in-place req.Headers["x-hello-world-plugin"] = "transport-pre-hook-value" // Store value in context for PreLLMHook/PostLLMHook - ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-transport-pre-hook"), "transport-pre-hook-value") + ctx.SetValue(transportPreHookKey, "transport-pre-hook-value") // Return nil to continue processing, or return &schemas.HTTPResponse{} to short-circuit + ctx.Log(schemas.LogLevelInfo, "HTTPTransportPreHook called") return nil, nil } @@ -33,7 +40,8 @@ func HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest // Modify response in-place resp.Headers["x-hello-world-plugin"] = "transport-post-hook-value" // Store value in context - ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-transport-post-hook"), "transport-post-hook-value") + ctx.Log(schemas.LogLevelInfo, "HTTPTransportPostHook called") + ctx.SetValue(transportPostHookKey, "transport-post-hook-value") // Return nil to continue processing return nil } @@ -41,6 +49,7 @@ func HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest func HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { fmt.Println("HTTPTransportStreamChunkHook called") // Modify chunk in-place + ctx.Log(schemas.LogLevelInfo, "HTTPTransportStreamChunkHook called") if chunk.BifrostChatResponse != nil && chunk.BifrostChatResponse.Choices != nil && len(chunk.BifrostChatResponse.Choices) > 0 && chunk.BifrostChatResponse.Choices[0].ChatStreamResponseChoice != nil && chunk.BifrostChatResponse.Choices[0].ChatStreamResponseChoice.Delta != nil && chunk.BifrostChatResponse.Choices[0].ChatStreamResponseChoice.Delta.Content != nil { *chunk.BifrostChatResponse.Choices[0].ChatStreamResponseChoice.Delta.Content += " - modified by hello-world-plugin" } @@ -49,19 +58,21 @@ func HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTP } func PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { - value1 := ctx.Value(schemas.BifrostContextKey("hello-world-plugin-transport-pre-hook")) + value1 := ctx.Value(transportPreHookKey) fmt.Println("value1:", value1) - ctx.SetValue(schemas.BifrostContextKey("hello-world-plugin-pre-hook"), "pre-hook-value") + ctx.SetValue(preHookKey, "pre-hook-value") + ctx.Log(schemas.LogLevelInfo, "PreLLMHook called") fmt.Println("PreLLMHook called") return req, nil, nil } func PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { fmt.Println("PostLLMHook called") - value1 := ctx.Value(schemas.BifrostContextKey("hello-world-plugin-transport-pre-hook")) + value1 := ctx.Value(transportPreHookKey) fmt.Println("value1:", value1) - value2 := ctx.Value(schemas.BifrostContextKey("hello-world-plugin-pre-hook")) + value2 := ctx.Value(preHookKey) fmt.Println("value2:", value2) + ctx.Log(schemas.LogLevelInfo, "PostLLMHook called") return resp, bifrostErr, nil } diff --git a/framework/changelog.md b/framework/changelog.md index e69de29bb2..9925e68f00 100644 --- a/framework/changelog.md +++ b/framework/changelog.md @@ -0,0 +1,16 @@ +- feat: add per-user OAuth consent flow with identity selection and MCP authentication +- feat: add access profiles for fine-grained permission control +- feat: add user level OAuth for MCP gateway +- feat: add IsSet method to EnvVar and improve provider auth validation +- feat: add session log storage and realtime request normalization +- feat: add support for tracking userId, teamId, customerId, and businessUnitId +- feat: add prompts plugin with direct key header resolver +- feat: add Fireworks AI provider support (thanks [@ivanetchart](https://github.com/ivanetchart)!) +- feat: add sorting and CSV export to virtual keys table +- feat: allow path whitelisting from security config +- fix: auto-redact env-backed values in EnvVar JSON serialization +- fix: MCP tool logs not being captured correctly +- fix: SQLite migration connections and error handling +- fix: disable SQLite foreign key checks during migration +- fix: add retry mechanism to model catalog pricing sync lock +- fix: increases buffer size for custom plugin installs from URLs diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 00970ce25e..698437c328 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "maps" "sort" "strconv" @@ -55,12 +56,14 @@ type ClientConfig struct { MCPToolExecutionTimeout int `json:"mcp_tool_execution_timeout"` // The timeout for individual tool execution in seconds MCPCodeModeBindingLevel string `json:"mcp_code_mode_binding_level"` // Code mode binding level: "server" or "tool" MCPToolSyncInterval int `json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled) + MCPDisableAutoToolInject bool `json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default HeaderFilterConfig *tables.GlobalHeaderFilterConfig `json:"header_filter_config,omitempty"` // Global header filtering configuration for x-bf-eh-* headers AsyncJobResultTTL int `json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour) RequiredHeaders []string `json:"required_headers,omitempty"` // Headers that must be present on every request (case-insensitive) LoggingHeaders []string `json:"logging_headers,omitempty"` // Headers to capture in log metadata WhitelistedRoutes []string `json:"whitelisted_routes,omitempty"` // Routes that bypass auth middleware HideDeletedVirtualKeysInFilters bool `json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys from logs/MCP filter data + RoutingChainMaxDepth int `json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) } @@ -118,6 +121,14 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("hideDeletedVirtualKeysInFilters:true")) } + // Always hash when non-zero β€” explicitly setting the default (10) is a meaningful + // config change that should be reflected in the hash. The migration that introduces + // this field backfills existing rows with RoutingChainMaxDepth=10 and regenerates + // their config_hash so there is no hash churn on upgrade for unmodified configs. + if c.RoutingChainMaxDepth > 0 { + hash.Write([]byte("routingChainMaxDepth:" + strconv.Itoa(c.RoutingChainMaxDepth))) + } + if c.MCPAgentDepth > 0 { hash.Write([]byte("mcpAgentDepth:" + strconv.Itoa(c.MCPAgentDepth))) } else { @@ -142,6 +153,11 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write([]byte("mcpToolSyncInterval:0")) } + // Only hash non-default value to avoid legacy config hash churn on upgrade. + if c.MCPDisableAutoToolInject { + hash.Write([]byte("mcpDisableAutoToolInject:true")) + } + if c.AsyncJobResultTTL > 0 { hash.Write([]byte("asyncJobResultTTL:" + strconv.Itoa(c.AsyncJobResultTTL))) } else { @@ -203,6 +219,19 @@ func (c *ClientConfig) GenerateClientConfigHash() (string, error) { hash.Write(data) } + // Hash LoggingHeaders (sorted for deterministic hashing) + if len(c.LoggingHeaders) > 0 { + sortedLogging := make([]string, len(c.LoggingHeaders)) + copy(sortedLogging, c.LoggingHeaders) + sort.Strings(sortedLogging) + data, err := sonic.Marshal(sortedLogging) + if err != nil { + return "", err + } + hash.Write([]byte("loggingHeaders:")) + hash.Write(data) + } + // Hash RequiredHeaders (sorted for deterministic hashing) if len(c.RequiredHeaders) > 0 { sortedRequired := make([]string, len(c.RequiredHeaders)) @@ -272,7 +301,6 @@ type ProviderConfig struct { StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration - PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection Status string `json:"status,omitempty"` // Model discovery status for keyless providers Description string `json:"description,omitempty"` // Model discovery error message for keyless providers @@ -293,7 +321,6 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { StoreRawRequestResponse: p.StoreRawRequestResponse, CustomProviderConfig: p.CustomProviderConfig, OpenAIConfig: p.OpenAIConfig, - PricingOverrides: p.PricingOverrides, ConfigHash: p.ConfigHash, Status: p.Status, Description: p.Description, @@ -326,6 +353,9 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { enabled := *key.Enabled redactedConfig.Keys[i].Enabled = &enabled } + if key.Aliases != nil { + redactedConfig.Keys[i].Aliases = maps.Clone(key.Aliases) + } redactedConfig.Keys[i].Value = *key.Value.Redacted() // Add back use for batch api if key.UseForBatchAPI != nil { @@ -340,9 +370,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Azure key config if present if key.AzureKeyConfig != nil { - azureConfig := &schemas.AzureKeyConfig{ - Deployments: key.AzureKeyConfig.Deployments, - } + azureConfig := &schemas.AzureKeyConfig{} azureConfig.Endpoint = *key.AzureKeyConfig.Endpoint.Redacted() azureConfig.APIVersion = key.AzureKeyConfig.APIVersion if key.AzureKeyConfig.ClientID != nil { @@ -362,9 +390,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Vertex key config if present if key.VertexKeyConfig != nil { - vertexConfig := &schemas.VertexKeyConfig{ - Deployments: key.VertexKeyConfig.Deployments, - } + vertexConfig := &schemas.VertexKeyConfig{} vertexConfig.ProjectID = *key.VertexKeyConfig.ProjectID.Redacted() vertexConfig.ProjectNumber = *key.VertexKeyConfig.ProjectNumber.Redacted() vertexConfig.Region = *key.VertexKeyConfig.Region.Redacted() @@ -374,9 +400,7 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { // Redact Bedrock key config if present if key.BedrockKeyConfig != nil { - bedrockConfig := &schemas.BedrockKeyConfig{ - Deployments: key.BedrockKeyConfig.Deployments, - } + bedrockConfig := &schemas.BedrockKeyConfig{} bedrockConfig.AccessKey = *key.BedrockKeyConfig.AccessKey.Redacted() bedrockConfig.SecretKey = *key.BedrockKeyConfig.SecretKey.Redacted() if key.BedrockKeyConfig.SessionToken != nil { @@ -404,13 +428,6 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { redactedConfig.Keys[i].BedrockKeyConfig = bedrockConfig } - if key.ReplicateKeyConfig != nil { - replicateConfig := &schemas.ReplicateKeyConfig{ - Deployments: key.ReplicateKeyConfig.Deployments, - } - redactedConfig.Keys[i].ReplicateKeyConfig = replicateConfig - } - if key.VLLMKeyConfig != nil { vllmConfig := &schemas.VLLMKeyConfig{ ModelName: key.VLLMKeyConfig.ModelName, @@ -418,6 +435,25 @@ func (p *ProviderConfig) Redacted() *ProviderConfig { vllmConfig.URL = *key.VLLMKeyConfig.URL.Redacted() redactedConfig.Keys[i].VLLMKeyConfig = vllmConfig } + + if key.ReplicateKeyConfig != nil { + replicateConfig := &schemas.ReplicateKeyConfig{ + UseDeploymentsEndpoint: key.ReplicateKeyConfig.UseDeploymentsEndpoint, + } + redactedConfig.Keys[i].ReplicateKeyConfig = replicateConfig + } + + if key.OllamaKeyConfig != nil { + ollamaConfig := &schemas.OllamaKeyConfig{} + ollamaConfig.URL = *key.OllamaKeyConfig.URL.Redacted() + redactedConfig.Keys[i].OllamaKeyConfig = ollamaConfig + } + + if key.SGLKeyConfig != nil { + sglConfig := &schemas.SGLKeyConfig{} + sglConfig.URL = *key.SGLKeyConfig.URL.Redacted() + redactedConfig.Keys[i].SGLKeyConfig = sglConfig + } } return &redactedConfig } @@ -476,15 +512,6 @@ func (p *ProviderConfig) GenerateConfigHash(providerName string) (string, error) hash.Write(data) } - // Hash PricingOverrides - if p.PricingOverrides != nil { - data, err := sonic.Marshal(p.PricingOverrides) - if err != nil { - return "", err - } - hash.Write(data) - } - // Hash SendBackRawRequest if p.SendBackRawRequest { hash.Write([]byte("sendBackRawRequest")) @@ -569,9 +596,9 @@ func GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } - // Hash ReplicateKeyConfig - if key.ReplicateKeyConfig != nil { - data, err := sonic.Marshal(key.ReplicateKeyConfig) + // Hash Aliases + if key.Aliases != nil { + data, err := sonic.Marshal(key.Aliases) if err != nil { return "", err } @@ -585,6 +612,30 @@ func GenerateKeyHash(key schemas.Key) (string, error) { } hash.Write(data) } + // Hash ReplicateKeyConfig + if key.ReplicateKeyConfig != nil { + data, err := sonic.Marshal(key.ReplicateKeyConfig) + if err != nil { + return "", err + } + hash.Write(data) + } + // Hash OllamaKeyConfig + if key.OllamaKeyConfig != nil { + data, err := sonic.Marshal(key.OllamaKeyConfig) + if err != nil { + return "", err + } + hash.Write(data) + } + // Hash SGLKeyConfig + if key.SGLKeyConfig != nil { + data, err := sonic.Marshal(key.SGLKeyConfig) + if err != nil { + return "", err + } + hash.Write(data) + } // Hash Enabled (nil = false, only true produces different hash) if key.Enabled != nil && *key.Enabled { hash.Write([]byte("enabled:true")) @@ -609,7 +660,6 @@ type VirtualKeyHashInput struct { IsActive bool TeamID *string CustomerID *string - BudgetID *string RateLimitID *string // ProviderConfigs and MCPConfigs are hashed separately as they contain nested data ProviderConfigs []VirtualKeyProviderConfigHashInput @@ -619,9 +669,8 @@ type VirtualKeyHashInput struct { // VirtualKeyProviderConfigHashInput represents provider config fields for hashing type VirtualKeyProviderConfigHashInput struct { Provider string - Weight float64 + Weight *float64 AllowedModels []string - BudgetID *string RateLimitID *string KeyIDs []string // Only key IDs, not full key objects } @@ -657,10 +706,6 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { if vk.CustomerID != nil { hash.Write([]byte("customerID:" + *vk.CustomerID)) } - // Hash BudgetID - if vk.BudgetID != nil { - hash.Write([]byte("budgetID:" + *vk.BudgetID)) - } // Hash RateLimitID if vk.RateLimitID != nil { hash.Write([]byte("rateLimitID:" + *vk.RateLimitID)) @@ -674,16 +719,6 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { if sortedProviderConfigs[i].Provider != sortedProviderConfigs[j].Provider { return sortedProviderConfigs[i].Provider < sortedProviderConfigs[j].Provider } - bi, bj := "", "" - if sortedProviderConfigs[i].BudgetID != nil { - bi = *sortedProviderConfigs[i].BudgetID - } - if sortedProviderConfigs[j].BudgetID != nil { - bj = *sortedProviderConfigs[j].BudgetID - } - if bi != bj { - return bi < bj - } ri, rj := "", "" if sortedProviderConfigs[i].RateLimitID != nil { ri = *sortedProviderConfigs[i].RateLimitID @@ -694,7 +729,14 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { if ri != rj { return ri < rj } - return getWeight(sortedProviderConfigs[i].Weight) < getWeight(sortedProviderConfigs[j].Weight) + wi, wj := sortedProviderConfigs[i].Weight, sortedProviderConfigs[j].Weight + if (wi == nil) != (wj == nil) { + return wi == nil + } + if wi != nil && wj != nil && *wi != *wj { + return *wi < *wj + } + return false }) // Filter out provider configs that are not available providerConfigsForHash := make([]VirtualKeyProviderConfigHashInput, len(sortedProviderConfigs)) @@ -712,9 +754,8 @@ func GenerateVirtualKeyHash(vk tables.TableVirtualKey) (string, error) { sort.Strings(sortedAllowedModels) providerConfigsForHash[i] = VirtualKeyProviderConfigHashInput{ Provider: pc.Provider, - Weight: getWeight(pc.Weight), + Weight: pc.Weight, AllowedModels: sortedAllowedModels, - BudgetID: pc.BudgetID, RateLimitID: pc.RateLimitID, KeyIDs: keyIDs, } @@ -992,6 +1033,13 @@ func GenerateRoutingRuleHash(r tables.TableRoutingRule) (string, error) { hash.Write(data) } + // Hash ChainRule + if r.ChainRule { + hash.Write([]byte("chain_rule:true")) + } else { + hash.Write([]byte("chain_rule:false")) + } + // Hash Scope hash.Write([]byte(r.Scope)) @@ -1008,6 +1056,23 @@ func GenerateRoutingRuleHash(r tables.TableRoutingRule) (string, error) { return hex.EncodeToString(hash.Sum(nil)), nil } +// GeneratePricingOverrideHash generates a SHA256 hash for a pricing override. +// Skips: CreatedAt, UpdatedAt, ConfigHash (dynamic/meta fields). +func GeneratePricingOverrideHash(p tables.TablePricingOverride) (string, error) { + hash := sha256.New() + hash.Write([]byte(p.ID)) + hash.Write([]byte(p.Name)) + hash.Write([]byte(p.ScopeKind)) + hash.Write([]byte(derefStr(p.VirtualKeyID))) + hash.Write([]byte(derefStr(p.ProviderID))) + hash.Write([]byte(derefStr(p.ProviderKeyID))) + hash.Write([]byte(p.MatchType)) + hash.Write([]byte(p.Pattern)) + hash.Write([]byte(p.RequestTypesJSON)) + hash.Write([]byte(p.PricingPatchJSON)) + return hex.EncodeToString(hash.Sum(nil)), nil +} + // GenerateMCPClientHash generates a SHA256 hash for an MCP client. // This is used to detect changes to MCP clients between config.json and database. // Skips: ID (autoIncrement), CreatedAt, UpdatedAt (dynamic fields) @@ -1131,14 +1196,17 @@ type AuthConfig struct { // ConfigMap maps provider names to their configurations. type ConfigMap map[schemas.ModelProvider]ProviderConfig +// GovernanceConfig contains governance entities loaded from the config store or +// reconciled from config.json. type GovernanceConfig struct { - VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` - Teams []tables.TableTeam `json:"teams"` - Customers []tables.TableCustomer `json:"customers"` - Budgets []tables.TableBudget `json:"budgets"` - RateLimits []tables.TableRateLimit `json:"rate_limits"` - ModelConfigs []tables.TableModelConfig `json:"model_configs"` - Providers []tables.TableProvider `json:"providers"` - RoutingRules []tables.TableRoutingRule `json:"routing_rules"` - AuthConfig *AuthConfig `json:"auth_config,omitempty"` + VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` + Teams []tables.TableTeam `json:"teams"` + Customers []tables.TableCustomer `json:"customers"` + Budgets []tables.TableBudget `json:"budgets"` + RateLimits []tables.TableRateLimit `json:"rate_limits"` + ModelConfigs []tables.TableModelConfig `json:"model_configs"` + Providers []tables.TableProvider `json:"providers"` + RoutingRules []tables.TableRoutingRule `json:"routing_rules"` + PricingOverrides []tables.TablePricingOverride `json:"pricing_overrides,omitempty"` + AuthConfig *AuthConfig `json:"auth_config,omitempty"` } diff --git a/framework/configstore/clientconfig_redaction_test.go b/framework/configstore/clientconfig_redaction_test.go new file mode 100644 index 0000000000..8bff430fb0 --- /dev/null +++ b/framework/configstore/clientconfig_redaction_test.go @@ -0,0 +1,242 @@ +package configstore + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestProviderConfig_Redacted_AutoMasksEnvBackedFields verifies that env-backed +// values in any provider config field are automatically redacted in the JSON output +// of a Redacted() ProviderConfig β€” even fields that don't have explicit Redacted() +// calls (like Azure APIVersion). This is the defense-in-depth guarantee provided +// by EnvVar.MarshalJSON. +func TestProviderConfig_Redacted_AutoMasksEnvBackedFields(t *testing.T) { + t.Setenv("MY_AZURE_API_VERSION_SECRET", "2024-10-21-preview-secret") + + apiVersion := schemas.NewEnvVar("env.MY_AZURE_API_VERSION_SECRET") + require.True(t, apiVersion.IsFromEnv(), "setup: APIVersion should be FromEnv") + require.Equal(t, "2024-10-21-preview-secret", apiVersion.GetValue(), + "setup: APIVersion should be resolved") + + config := ProviderConfig{ + Keys: []schemas.Key{{ + ID: "k1", + Name: "test", + Value: schemas.EnvVar{Val: ""}, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"), + APIVersion: apiVersion, + }, + }}, + } + + redacted := config.Redacted() + require.NotNil(t, redacted) + require.Len(t, redacted.Keys, 1) + require.NotNil(t, redacted.Keys[0].AzureKeyConfig) + require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion) + + // Marshal the APIVersion field as it would be sent to the UI. + data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion) + require.NoError(t, err) + + var out struct { + Value string `json:"value"` + EnvVar string `json:"env_var"` + FromEnv bool `json:"from_env"` + } + require.NoError(t, json.Unmarshal(data, &out)) + + assert.NotContains(t, out.Value, "preview-secret", + "resolved env value leaked through APIVersion JSON output: %q", out.Value) + assert.Equal(t, "env.MY_AZURE_API_VERSION_SECRET", out.EnvVar, + "env var reference must be preserved so the UI can show it") + assert.True(t, out.FromEnv, "from_env flag must be preserved") +} + +// TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields verifies that the +// auto-redaction does NOT touch plain (non-env-backed) values. A user-typed +// api_version like "2024-10-21" must show as-is in the UI. +func TestProviderConfig_Redacted_DoesNotMaskPlainNonSecretFields(t *testing.T) { + config := ProviderConfig{ + Keys: []schemas.Key{{ + ID: "k1", + Name: "test", + Value: schemas.EnvVar{Val: ""}, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar("https://foo.openai.azure.com"), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }}, + } + + redacted := config.Redacted() + require.NotNil(t, redacted) + require.Len(t, redacted.Keys, 1) + require.NotNil(t, redacted.Keys[0].AzureKeyConfig) + require.NotNil(t, redacted.Keys[0].AzureKeyConfig.APIVersion) + + data, err := json.Marshal(redacted.Keys[0].AzureKeyConfig.APIVersion) + require.NoError(t, err) + + var out struct { + Value string `json:"value"` + FromEnv bool `json:"from_env"` + } + require.NoError(t, json.Unmarshal(data, &out)) + + assert.Equal(t, "2024-10-21", out.Value, + "plain APIVersion was incorrectly redacted") + assert.False(t, out.FromEnv) +} + +// TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex verifies that +// env-backed Vertex fields appear in the redacted output with the env reference +// intact and the resolved value masked. This is the user-facing fix for the +// "I see resolved env values in the UI" bug. +func TestProviderConfig_Redacted_PreservesEnvVarReferenceForVertex(t *testing.T) { + t.Setenv("MY_VERTEX_PROJECT_ID_SECRET", "super-secret-project-12345") + + projectID := schemas.NewEnvVar("env.MY_VERTEX_PROJECT_ID_SECRET") + require.Equal(t, "super-secret-project-12345", projectID.GetValue()) + + config := ProviderConfig{ + Keys: []schemas.Key{{ + ID: "k1", + Name: "test", + Value: schemas.EnvVar{Val: ""}, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *projectID, + Region: *schemas.NewEnvVar("us-central1"), + }, + }}, + } + + redacted := config.Redacted() + data, err := json.Marshal(redacted.Keys[0].VertexKeyConfig.ProjectID) + require.NoError(t, err) + + var out struct { + Value string `json:"value"` + EnvVar string `json:"env_var"` + FromEnv bool `json:"from_env"` + } + require.NoError(t, json.Unmarshal(data, &out)) + + assert.NotContains(t, out.Value, "super-secret-project", + "resolved Vertex ProjectID env value leaked: %q", out.Value) + assert.Equal(t, "env.MY_VERTEX_PROJECT_ID_SECRET", out.EnvVar) + assert.True(t, out.FromEnv) +} + +// TestProviderConfig_Redacted_DoesNotMutateOriginal ensures Redacted() and the +// subsequent JSON marshaling do not mutate the original config in memory. The +// inference path reads from the in-memory config and calls GetValue() to build +// outgoing LLM requests; if Redacted() or MarshalJSON were to mutate state, every +// inference request after a UI fetch would silently start using masked values. +func TestProviderConfig_Redacted_DoesNotMutateOriginal(t *testing.T) { + t.Setenv("MY_REAL_KEY", "sk-real-secret-1234567890abcdef") + + keyValue := schemas.NewEnvVar("env.MY_REAL_KEY") + require.Equal(t, "sk-real-secret-1234567890abcdef", keyValue.GetValue()) + + config := ProviderConfig{ + Keys: []schemas.Key{{ + ID: "k1", + Name: "test", + Value: *keyValue, + }}, + } + + redacted := config.Redacted() + _, err := json.Marshal(redacted) + require.NoError(t, err) + + // Original must still hold the resolved value. + assert.Equal(t, "sk-real-secret-1234567890abcdef", config.Keys[0].Value.GetValue(), + "Redacted() or MarshalJSON mutated the original key Value") +} + +// TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets is a high-level smoke +// test: build a config containing env-backed values across multiple provider types +// and assert that no resolved secret string appears anywhere in the marshaled +// redacted JSON. +func TestProviderConfig_Redacted_FullJSONHasNoLeakedEnvSecrets(t *testing.T) { + t.Setenv("LEAK_TEST_AZURE_ENDPOINT", "https://leaked-azure.example.com") + t.Setenv("LEAK_TEST_AZURE_APIVER", "leaked-api-version-string") + t.Setenv("LEAK_TEST_VERTEX_PROJECT", "leaked-vertex-project-id") + t.Setenv("LEAK_TEST_BEDROCK_ACCESS", "AKIAIOSFODNN7LEAKED1") + t.Setenv("LEAK_TEST_OPENAI_KEY", "sk-leaked-openai-key-1234567890") + + config := ProviderConfig{ + Keys: []schemas.Key{ + { + ID: "openai-k", + Name: "openai", + Value: *schemas.NewEnvVar("env.LEAK_TEST_OPENAI_KEY"), + }, + { + ID: "azure-k", + Name: "azure", + Value: schemas.EnvVar{Val: ""}, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar("env.LEAK_TEST_AZURE_ENDPOINT"), + APIVersion: schemas.NewEnvVar("env.LEAK_TEST_AZURE_APIVER"), + }, + }, + { + ID: "vertex-k", + Name: "vertex", + Value: schemas.EnvVar{Val: ""}, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.LEAK_TEST_VERTEX_PROJECT"), + Region: *schemas.NewEnvVar("us-central1"), + }, + }, + { + ID: "bedrock-k", + Name: "bedrock", + Value: schemas.EnvVar{Val: ""}, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: *schemas.NewEnvVar("env.LEAK_TEST_BEDROCK_ACCESS"), + SecretKey: schemas.EnvVar{Val: ""}, + }, + }, + }, + } + + redacted := config.Redacted() + data, err := json.Marshal(redacted) + require.NoError(t, err) + jsonStr := string(data) + + leakedSecrets := []string{ + "https://leaked-azure.example.com", + "leaked-api-version-string", + "leaked-vertex-project-id", + "AKIAIOSFODNN7LEAKED1", + "sk-leaked-openai-key-1234567890", + } + for _, secret := range leakedSecrets { + assert.False(t, strings.Contains(jsonStr, secret), + "resolved env secret %q leaked into redacted JSON output", secret) + } + + // And the env var references must be present so the UI can render them. + expectedRefs := []string{ + "env.LEAK_TEST_OPENAI_KEY", + "env.LEAK_TEST_AZURE_ENDPOINT", + "env.LEAK_TEST_AZURE_APIVER", + "env.LEAK_TEST_VERTEX_PROJECT", + "env.LEAK_TEST_BEDROCK_ACCESS", + } + for _, ref := range expectedRefs { + assert.True(t, strings.Contains(jsonStr, ref), + "env var reference %q missing from redacted JSON output", ref) + } +} diff --git a/framework/configstore/encryption_test.go b/framework/configstore/encryption_test.go index 0d1f3625cd..9ac36baede 100644 --- a/framework/configstore/encryption_test.go +++ b/framework/configstore/encryption_test.go @@ -725,7 +725,7 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test now := time.Now().UTC().Format("2006-01-02 15:04:05") insertPlaintextRow(t, db, - `INSERT INTO config_keys (name, provider_id, provider, key_id, value, bedrock_access_key, bedrock_secret_key, bedrock_session_token, bedrock_region, bedrock_arn, bedrock_deployments_json, bedrock_batch_s3_config_json, encryption_status, created_at, updated_at) + `INSERT INTO config_keys (name, provider_id, provider, key_id, value, bedrock_access_key, bedrock_secret_key, bedrock_session_token, bedrock_region, bedrock_arn, aliases_json, bedrock_batch_s3_config_json, encryption_status, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'plain_text', ?, ?)`, "bedrock-key", 1, "bedrock", "br-1", "sk-bedrock-key-value", "AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "FwoGZXIvYXdzEBYaDH7sampleSessionToken", @@ -747,9 +747,15 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test assert.NotEqual(t, "FwoGZXIvYXdzEBYaDH7sampleSessionToken", raw["bedrock_session_token"]) assert.NotEqual(t, "us-west-2", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:iam::123456789:role/bedrock", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-claude") + rawAliasesVal := raw["aliases_json"] + var rawAliasesStr string + switch v := rawAliasesVal.(type) { + case string: + rawAliasesStr = v + case []byte: + rawAliasesStr = string(v) } + assert.NotContains(t, rawAliasesStr, "profile-claude") if rawBatch, ok := raw["bedrock_batch_s3_config_json"].(string); ok { assert.NotContains(t, rawBatch, "my-bucket") } @@ -767,7 +773,7 @@ func TestEncryptPlaintextKeys_BedrockFields_EncryptsAndDecryptsCorrectly(t *test assert.Equal(t, "us-west-2", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:iam::123456789:role/bedrock", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-claude", found.BedrockKeyConfig.Deployments["claude-3"]) + assert.Equal(t, "profile-claude", found.Aliases["claude-3"]) require.NotNil(t, found.BedrockKeyConfig.BatchS3Config) require.Len(t, found.BedrockKeyConfig.BatchS3Config.Buckets, 1) assert.Equal(t, "my-bucket", found.BedrockKeyConfig.BatchS3Config.Buckets[0].BucketName) diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 908824d754..e64351eeaa 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -274,14 +274,13 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddEnforceAuthOnInferenceColumn(ctx, db); err != nil { return err } - if err := migrationAddProviderPricingOverridesColumn(ctx, db); err != nil { + if err := migrationReconcilePricingOverridesTable(ctx, db); err != nil { return err } if err := migrationAddEncryptionColumns(ctx, db); err != nil { return err } if err := migrationAddOutputCostPerVideoPerSecond(ctx, db); err != nil { - return err } if err := migrationDropEnableGovernanceColumn(ctx, db); err != nil { @@ -317,18 +316,63 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddPluginOrderColumns(ctx, db); err != nil { return err } + if err := migrationAddAllowAllKeysToProviderConfig(ctx, db); err != nil { + return err + } + if err := migrationBackfillEmptyVirtualKeyConfigs(ctx, db); err != nil { + return err + } + if err := migrationAddMCPDisableAutoToolInjectColumn(ctx, db); err != nil { + return err + } + if err := migrationBackfillAllowedModelsWildcard(ctx, db); err != nil { + return err + } + if err := migrationAddMCPClientAllowedExtraHeadersJSONColumn(ctx, db); err != nil { + return err + } + if err := migrationMakeBasePricingColumnsNullable(ctx, db); err != nil { + return err + } + if err := migrationAddAllowOnAllVirtualKeysColumn(ctx, db); err != nil { + return err + } if err := migrationAddOpenAIConfigJSONColumn(ctx, db); err != nil { return err } if err := migrationAddKeyBlacklistedModelsJSONColumn(ctx, db); err != nil { return err } + if err := migrationAddChainRuleColumnToRoutingRules(ctx, db); err != nil { + return err + } + if err := migrationDropDeploymentColumnsAndAddAliases(ctx, db); err != nil { + return err + } + if err := migrationAddReplicateKeyConfigColumn(ctx, db); err != nil { + return err + } if err := migrationAddBudgetCalendarAlignedColumn(ctx, db); err != nil { return err } + if err := migrationAddRoutingChainMaxDepthColumn(ctx, db); err != nil { + return err + } if err := migrationAddModelCapabilityColumns(ctx, db); err != nil { return err } + if err := migrationAddOllamaSGLConfigColumns(ctx, db); err != nil { + return err + } + if err := migrationAddMultiBudgetTables(ctx, db); err != nil { + return err + } + if err := migrationAddPerUserOAuthTables(ctx, db); err != nil { + return err + } + if err := migrationAddMCPClientDiscoveredToolsColumns(ctx, db); err != nil { + return err + } if err := migrationAddWhitelistedRoutesJSONColumn(ctx, db); err != nil { return err } @@ -358,7 +402,6 @@ func migrationAddStoreRawRequestResponseColumn(ctx context.Context, db *gorm.DB) "concurrency_buffer_json", "proxy_config_json", "custom_provider_config_json", - "pricing_overrides_json", "send_back_raw_request", "send_back_raw_response", "store_raw_request_response", @@ -376,7 +419,6 @@ func migrationAddStoreRawRequestResponseColumn(ctx context.Context, db *gorm.DB) SendBackRawResponse: provider.SendBackRawResponse, StoreRawRequestResponse: provider.StoreRawRequestResponse, CustomProviderConfig: provider.CustomProviderConfig, - PricingOverrides: provider.PricingOverrides, } // Here the default value of store_raw_request_response should be based on the default value of SendBackRawRequest and SendBackRawResponse if provider.SendBackRawRequest || provider.SendBackRawResponse { @@ -514,6 +556,11 @@ func migrationInit(ctx context.Context, db *gorm.DB) error { return err } } + if !migrator.HasTable(&tables.TablePricingOverride{}) { + if err := migrator.CreateTable(&tables.TablePricingOverride{}); err != nil { + return err + } + } if !migrator.HasTable(&tables.TablePlugin{}) { if err := migrator.CreateTable(&tables.TablePlugin{}); err != nil { return err @@ -571,6 +618,9 @@ func migrationInit(ctx context.Context, db *gorm.DB) error { if err := migrator.DropTable(&tables.TableModelPricing{}); err != nil { return err } + if err := migrator.DropTable(&tables.TablePricingOverride{}); err != nil { + return err + } if err := migrator.DropTable(&tables.TablePlugin{}); err != nil { return err } @@ -901,8 +951,8 @@ func migrationCleanupMCPClientToolsConfig(ctx context.Context, db *gorm.DB) erro // Step 2: Update empty ToolsToExecuteJSON arrays to wildcard ["*"] // Convert "[]" (empty array) to "[\"*\"]" (wildcard array) for backward compatibility updateSQL := ` - UPDATE config_mcp_clients - SET tools_to_execute_json = '["*"]' + UPDATE config_mcp_clients + SET tools_to_execute_json = '["*"]' WHERE tools_to_execute_json = '[]' OR tools_to_execute_json = '' OR tools_to_execute_json IS NULL ` if err := tx.Exec(updateSQL).Error; err != nil { @@ -917,8 +967,8 @@ func migrationCleanupMCPClientToolsConfig(ctx context.Context, db *gorm.DB) erro tx = tx.WithContext(ctx) revertSQL := ` - UPDATE config_mcp_clients - SET tools_to_execute_json = '[]' + UPDATE config_mcp_clients + SET tools_to_execute_json = '[]' WHERE tools_to_execute_json = '["*"]' ` if err := tx.Exec(revertSQL).Error; err != nil { @@ -973,12 +1023,13 @@ func migrationAddProviderConfigBudgetRateLimit(ctx context.Context, db *gorm.DB) tx = tx.WithContext(ctx) migrator := tx.Migrator() - // Add BudgetID column if it doesn't exist + // Add budget_id and rate_limit_id columns if they don't exist + // Note: budget_id is added via raw SQL because the field was later removed from the struct + // (migrated to governance_budgets.provider_config_id in add_multi_budget_tables) if migrator.HasTable(&tables.TableVirtualKeyProviderConfig{}) { - if !migrator.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "budget_id") { - if err := migrator.AddColumn(&tables.TableVirtualKeyProviderConfig{}, "budget_id"); err != nil { - return fmt.Errorf("failed to add budget_id column: %w", err) - } + if err := tx.Exec("ALTER TABLE governance_virtual_key_provider_configs ADD COLUMN IF NOT EXISTS budget_id VARCHAR(255)").Error; err != nil { + // Ignore error for databases that don't support IF NOT EXISTS (e.g., SQLite) + // The column may already exist from a previous run } // Add RateLimitID column if it doesn't exist @@ -989,10 +1040,8 @@ func migrationAddProviderConfigBudgetRateLimit(ctx context.Context, db *gorm.DB) } // Create foreign key indexes for better performance - if !migrator.HasIndex(&tables.TableVirtualKeyProviderConfig{}, "idx_provider_config_budget") { - if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_provider_config_budget ON governance_virtual_key_provider_configs (budget_id)").Error; err != nil { - return fmt.Errorf("failed to create budget_id index: %w", err) - } + if err := tx.Exec("CREATE INDEX IF NOT EXISTS idx_provider_config_budget ON governance_virtual_key_provider_configs (budget_id)").Error; err != nil { + // Ignore - index may already exist or column may not exist yet } if !migrator.HasIndex(&tables.TableVirtualKeyProviderConfig{}, "idx_provider_config_rate_limit") { @@ -1001,12 +1050,7 @@ func migrationAddProviderConfigBudgetRateLimit(ctx context.Context, db *gorm.DB) } } - // Create FK constraints (dialect‑agnostic) - if !migrator.HasConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budget") { - if err := migrator.CreateConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budget"); err != nil { - return fmt.Errorf("failed to create Budget FK constraint: %w", err) - } - } + // Create FK constraint for RateLimit (Budget FK is no longer needed - budgets use direct FK on budget table) if !migrator.HasConstraint(&tables.TableVirtualKeyProviderConfig{}, "RateLimit") { if err := migrator.CreateConstraint(&tables.TableVirtualKeyProviderConfig{}, "RateLimit"); err != nil { return fmt.Errorf("failed to create RateLimit FK constraint: %w", err) @@ -1020,32 +1064,19 @@ func migrationAddProviderConfigBudgetRateLimit(ctx context.Context, db *gorm.DB) tx = tx.WithContext(ctx) migrator := tx.Migrator() - // Drop indexes first - if err := tx.Exec("DROP INDEX IF EXISTS idx_provider_config_budget").Error; err != nil { - return fmt.Errorf("failed to drop budget_id index: %w", err) - } - if err := tx.Exec("DROP INDEX IF EXISTS idx_provider_config_rate_limit").Error; err != nil { - return fmt.Errorf("failed to drop rate_limit_id index: %w", err) - } + // Drop indexes + _ = tx.Exec("DROP INDEX IF EXISTS idx_provider_config_budget") + _ = tx.Exec("DROP INDEX IF EXISTS idx_provider_config_rate_limit") // Drop FK constraints - if migrator.HasConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budget") { - if err := migrator.DropConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budget"); err != nil { - return fmt.Errorf("failed to drop Budget FK constraint: %w", err) - } - } if migrator.HasConstraint(&tables.TableVirtualKeyProviderConfig{}, "RateLimit") { if err := migrator.DropConstraint(&tables.TableVirtualKeyProviderConfig{}, "RateLimit"); err != nil { return fmt.Errorf("failed to drop RateLimit FK constraint: %w", err) } } - // Drop columns - if migrator.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "budget_id") { - if err := migrator.DropColumn(&tables.TableVirtualKeyProviderConfig{}, "budget_id"); err != nil { - return fmt.Errorf("failed to drop budget_id column: %w", err) - } - } + // Drop columns via raw SQL (budget_id no longer on struct) + _ = tx.Exec("ALTER TABLE governance_virtual_key_provider_configs DROP COLUMN IF EXISTS budget_id") if migrator.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "rate_limit_id") { if err := migrator.DropColumn(&tables.TableVirtualKeyProviderConfig{}, "rate_limit_id"); err != nil { return fmt.Errorf("failed to drop rate_limit_id column: %w", err) @@ -1288,15 +1319,15 @@ func migrationAddVertexProjectNumberColumn(ctx context.Context, db *gorm.DB) err return nil } -// migrationAddVertexDeploymentsJSONColumn adds the vertex_deployments_json column to the key table +// migrationAddVertexDeploymentsJSONColumn adds the vertex_deployments_json column to the key table. +// This column is later dropped by migrationDropDeploymentColumnsAndAddAliases after data is migrated. func migrationAddVertexDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ ID: "add_vertex_deployments_json_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableKey{}, "vertex_deployments_json") { - if err := migrator.AddColumn(&tables.TableKey{}, "vertex_deployments_json"); err != nil { + if !tx.Migrator().HasColumn(&tables.TableKey{}, "vertex_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys ADD COLUMN vertex_deployments_json TEXT").Error; err != nil { return err } } @@ -1304,15 +1335,15 @@ func migrationAddVertexDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) e }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if err := migrator.DropColumn(&tables.TableKey{}, "vertex_deployments_json"); err != nil { - return err + if tx.Migrator().HasColumn(&tables.TableKey{}, "vertex_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN vertex_deployments_json").Error; err != nil { + return err + } } return nil }, }}) - err := m.Migrate() - if err != nil { + if err := m.Migrate(); err != nil { return fmt.Errorf("error while running vertex deployments JSON migration: %s", err.Error()) } return nil @@ -2012,14 +2043,14 @@ func migrationAddConfigHashColumn(ctx context.Context, db *gorm.DB) error { if key.ConfigHash == "" { // Convert to schemas.Key and generate hash schemaKey := schemas.Key{ - Name: key.Name, - Value: key.Value, - Models: key.Models, - Weight: getWeight(key.Weight), - AzureKeyConfig: key.AzureKeyConfig, - VertexKeyConfig: key.VertexKeyConfig, - BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Name: key.Name, + Value: key.Value, + Models: key.Models, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, } hash, err := GenerateKeyHash(schemaKey) if err != nil { @@ -3488,15 +3519,15 @@ func migrationAddAzureScopesColumn(ctx context.Context, db *gorm.DB) error { return nil } -// migrationAddReplicateDeploymentsJSONColumn adds the replicate_deployments_json column to the key table +// migrationAddReplicateDeploymentsJSONColumn adds the replicate_deployments_json column to the key table. +// This column is later dropped by migrationDropDeploymentColumnsAndAddAliases after data is migrated. func migrationAddReplicateDeploymentsJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ ID: "add_replicate_deployments_json_column", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableKey{}, "replicate_deployments_json") { - if err := migrator.AddColumn(&tables.TableKey{}, "replicate_deployments_json"); err != nil { + if !tx.Migrator().HasColumn(&tables.TableKey{}, "replicate_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys ADD COLUMN replicate_deployments_json TEXT").Error; err != nil { return err } } @@ -3504,20 +3535,123 @@ func migrationAddReplicateDeploymentsJSONColumn(ctx context.Context, db *gorm.DB }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if err := migrator.DropColumn(&tables.TableKey{}, "replicate_deployments_json"); err != nil { - return err + if tx.Migrator().HasColumn(&tables.TableKey{}, "replicate_deployments_json") { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN replicate_deployments_json").Error; err != nil { + return err + } } return nil }, }}) - err := m.Migrate() - if err != nil { + if err := m.Migrate(); err != nil { return fmt.Errorf("error while running replicate deployments JSON migration: %s", err.Error()) } return nil } +// migrationDropDeploymentColumnsAndAddAliases adds the unified aliases_json column, migrates +// existing per-provider deployment data into it, then drops the legacy columns. +// Only one deployment column will be populated per row (they were mutually exclusive). +func migrationDropDeploymentColumnsAndAddAliases(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "drop_deployment_columns_and_add_aliases", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + m := tx.Migrator() + + // Add aliases_json column first + if !m.HasColumn(&tables.TableKey{}, "aliases_json") { + if err := m.AddColumn(&tables.TableKey{}, "aliases_json"); err != nil { + return err + } + } + + // Copy data from whichever legacy deployment column is populated into aliases_json. + // Only rows where aliases_json is not already set are touched. + // Exactly one deployment column will be non-null per row (they were mutually exclusive). + for _, col := range []string{ + "azure_deployments_json", + "vertex_deployments_json", + "bedrock_deployments_json", + "replicate_deployments_json", + } { + if !m.HasColumn(&tables.TableKey{}, col) { + continue + } + if err := tx.Exec( + "UPDATE config_keys SET aliases_json = " + col + + " WHERE aliases_json IS NULL AND " + col + " IS NOT NULL AND " + col + " != ''", + ).Error; err != nil { + return err + } + } + + // Drop legacy deployment columns + for _, col := range []string{ + "azure_deployments_json", + "vertex_deployments_json", + "bedrock_deployments_json", + "replicate_deployments_json", + } { + if m.HasColumn(&tables.TableKey{}, col) { + if err := tx.Exec("ALTER TABLE config_keys DROP COLUMN " + col).Error; err != nil { + return err + } + } + } + + // Recompute config_hash for keys that had aliases_json populated above, + // since aliases_json is part of the hash input and these rows now have stale hashes. + var affectedKeys []tables.TableKey + if err := tx.Where( + "aliases_json IS NOT NULL AND aliases_json != ? AND aliases_json != ?", "", "{}", + ).Find(&affectedKeys).Error; err != nil { + return fmt.Errorf("failed to fetch keys for hash recomputation: %w", err) + } + for _, key := range affectedKeys { + schemaKey := schemas.Key{ + Name: key.Name, + Value: key.Value, + Models: key.Models, + BlacklistedModels: key.BlacklistedModels, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + } + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for key %s: %w", key.Name, err) + } + if err := tx.Model(&key).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for key %s: %w", key.Name, err) + } + log.Printf("[Migration] Recomputed config_hash for key '%s' after aliases migration", key.Name) + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + m := tx.Migrator() + if m.HasColumn(&tables.TableKey{}, "aliases_json") { + if err := m.DropColumn(&tables.TableKey{}, "aliases_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running drop deployment columns and add aliases migration: %s", err.Error()) + } + return nil +} + // migrationAddKeyStatusColumns adds status and description columns to config_keys table // These columns track the status and description of each individual key func migrationAddKeyStatusColumns(ctx context.Context, db *gorm.DB) error { @@ -3708,6 +3842,126 @@ func migrationAddRateLimitToTeamsAndCustomers(ctx context.Context, db *gorm.DB) return nil } +// migrationBackfillEmptyVirtualKeyConfigs backfills existing virtual keys that have +// empty ProviderConfigs or MCPConfigs with all available providers/MCP clients. +// This preserves the previous "empty means all" behavior for existing VKs after +// the semantic change to "empty means none" (deny-by-default). +func migrationBackfillEmptyVirtualKeyConfigs(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "backfill_empty_virtual_key_configs", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + + // Step 1: Backfill ProviderConfigs for VKs that have none + // Find all virtual keys + var allVKs []tables.TableVirtualKey + if err := tx.Find(&allVKs).Error; err != nil { + return fmt.Errorf("failed to query virtual keys: %w", err) + } + + // Get all available providers + var allProviders []tables.TableProvider + if err := tx.Find(&allProviders).Error; err != nil { + return fmt.Errorf("failed to query providers: %w", err) + } + + // Track which VK IDs were modified so we can recompute their config_hash + modifiedVKIDs := make(map[string]struct{}) + + for _, vk := range allVKs { + // Check if this VK has any provider configs + var providerConfigCount int64 + if err := tx.Model(&tables.TableVirtualKeyProviderConfig{}).Where("virtual_key_id = ?", vk.ID).Count(&providerConfigCount).Error; err != nil { + return fmt.Errorf("failed to count provider configs for VK %s: %w", vk.ID, err) + } + + if providerConfigCount == 0 && len(allProviders) > 0 { + // VK has no provider configs - backfill with all available providers + for _, provider := range allProviders { + providerConfig := tables.TableVirtualKeyProviderConfig{ + VirtualKeyID: vk.ID, + Provider: provider.Name, + Weight: bifrost.Ptr(1.0), + AllowedModels: []string{}, + AllowAllKeys: true, + } + if err := tx.Create(&providerConfig).Error; err != nil { + return fmt.Errorf("failed to create provider config for VK %s, provider %s: %w", vk.ID, provider.Name, err) + } + } + modifiedVKIDs[vk.ID] = struct{}{} + log.Printf("[Migration] Backfilled VK '%s' with %d provider configs", vk.Name, len(allProviders)) + } + } + + // Step 2: Backfill MCPConfigs for VKs that have none + // Get all available MCP clients + var allMCPClients []tables.TableMCPClient + if err := tx.Find(&allMCPClients).Error; err != nil { + return fmt.Errorf("failed to query MCP clients: %w", err) + } + + for _, vk := range allVKs { + // Check if this VK has any MCP configs + var mcpConfigCount int64 + if err := tx.Model(&tables.TableVirtualKeyMCPConfig{}).Where("virtual_key_id = ?", vk.ID).Count(&mcpConfigCount).Error; err != nil { + return fmt.Errorf("failed to count MCP configs for VK %s: %w", vk.ID, err) + } + + if mcpConfigCount == 0 && len(allMCPClients) > 0 { + // VK has no MCP configs - backfill with all available MCP clients with wildcard + for _, mcpClient := range allMCPClients { + mcpConfig := tables.TableVirtualKeyMCPConfig{ + VirtualKeyID: vk.ID, + MCPClientID: mcpClient.ID, + ToolsToExecute: []string{"*"}, + } + if err := tx.Create(&mcpConfig).Error; err != nil { + return fmt.Errorf("failed to create MCP config for VK %s, client %d: %w", vk.ID, mcpClient.ID, err) + } + } + modifiedVKIDs[vk.ID] = struct{}{} + log.Printf("[Migration] Backfilled VK '%s' with %d MCP client configs", vk.Name, len(allMCPClients)) + } + } + + // Step 3: Recompute and persist config_hash for every VK that was modified. + // Without this, subsequent config-sync diff logic would see a stale hash and + // attempt to re-reconcile the VK (potentially undoing the backfill). + for vkID := range modifiedVKIDs { + var vk tables.TableVirtualKey + if err := tx. + Preload("ProviderConfigs"). + Preload("ProviderConfigs.Keys"). + Preload("MCPConfigs"). + First(&vk, "id = ?", vkID).Error; err != nil { + return fmt.Errorf("failed to reload VK %s for hash recomputation: %w", vkID, err) + } + newHash, err := GenerateVirtualKeyHash(vk) + if err != nil { + return fmt.Errorf("failed to generate hash for VK %s: %w", vkID, err) + } + if err := tx.Model(&tables.TableVirtualKey{}). + Where("id = ?", vkID). + Update("config_hash", newHash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for VK %s: %w", vkID, err) + } + log.Printf("[Migration] Recomputed config_hash for VK '%s'", vk.Name) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + // No rollback needed - the backfilled configs are valid data + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running backfill empty virtual key configs migration: %s", err.Error()) + } + return nil +} + // migrationAddRequiredHeadersJSONColumn adds the required_headers_json column to the config_client table func migrationAddRequiredHeadersJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -3925,33 +4179,45 @@ func migrationAddEnforceAuthOnInferenceColumn(ctx context.Context, db *gorm.DB) return nil } -// migrationAddProviderPricingOverridesColumn adds the pricing_overrides_json column to the config_provider table -func migrationAddProviderPricingOverridesColumn(ctx context.Context, db *gorm.DB) error { +func migrationReconcilePricingOverridesTable(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ - ID: "add_provider_pricing_overrides_column", + ID: "reconcile_pricing_overrides_table", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TableProvider{}, "pricing_overrides_json") { - if err := migrator.AddColumn(&tables.TableProvider{}, "PricingOverridesJSON"); err != nil { - return fmt.Errorf("failed to add pricing_overrides_json column: %w", err) + mgr := tx.Migrator() + + if !mgr.HasTable(&tables.TablePricingOverride{}) { + if err := mgr.CreateTable(&tables.TablePricingOverride{}); err != nil { + return fmt.Errorf("failed to create governance_pricing_overrides table: %w", err) + } + return nil + } + if err := tx.AutoMigrate(&tables.TablePricingOverride{}); err != nil { + return fmt.Errorf("failed to automigrate governance_pricing_overrides table: %w", err) + } + for _, indexName := range []string{"idx_pricing_override_scope", "idx_pricing_override_match"} { + if mgr.HasIndex(&tables.TablePricingOverride{}, indexName) { + continue + } + if err := mgr.CreateIndex(&tables.TablePricingOverride{}, indexName); err != nil { + return fmt.Errorf("failed to create pricing override index %s: %w", indexName, err) } } return nil }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if migrator.HasColumn(&tables.TableProvider{}, "pricing_overrides_json") { - if err := migrator.DropColumn(&tables.TableProvider{}, "pricing_overrides_json"); err != nil { - return fmt.Errorf("failed to drop pricing_overrides_json column: %w", err) + mgr := tx.Migrator() + if mgr.HasTable(&tables.TablePricingOverride{}) { + if err := mgr.DropTable(&tables.TablePricingOverride{}); err != nil { + return fmt.Errorf("failed to drop governance_pricing_overrides table: %w", err) } } return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running provider pricing overrides column migration: %s", err.Error()) + return fmt.Errorf("error while running pricing overrides table reconcile migration: %s", err.Error()) } return nil } @@ -4231,6 +4497,124 @@ func migrationAddBedrockAssumeRoleColumns(ctx context.Context, db *gorm.DB) erro return nil } +// migrationAddAllowAllKeysToProviderConfig adds the allow_all_keys column to the provider config table +// and backfills existing rows: any provider config with no keys in the join table previously meant +// "allow all keys" (old semantic), so they get allow_all_keys = true to preserve behaviour. +func migrationAddAllowAllKeysToProviderConfig(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_allow_all_keys_to_provider_config", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + + // Add the column if it doesn't exist + if !migratorInstance.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "allow_all_keys") { + if err := migratorInstance.AddColumn(&tables.TableVirtualKeyProviderConfig{}, "allow_all_keys"); err != nil { + return fmt.Errorf("failed to add allow_all_keys column: %w", err) + } + } + + // Backfill: find all provider configs that have no keys in the join table. + // These previously meant "allow all keys", so set allow_all_keys = true. + var allConfigs []tables.TableVirtualKeyProviderConfig + if err := tx.Find(&allConfigs).Error; err != nil { + return fmt.Errorf("failed to query provider configs: %w", err) + } + + // Track which VK IDs were modified so we can recompute their config_hash. + // Without this, subsequent config-sync diff logic would see a stale hash + // and attempt to re-reconcile the VK (potentially undoing the backfill). + modifiedVKIDs := make(map[string]struct{}) + + for _, pc := range allConfigs { + var keyCount int64 + if err := tx.Table("governance_virtual_key_provider_config_keys"). + Where("table_virtual_key_provider_config_id = ?", pc.ID). + Count(&keyCount).Error; err != nil { + return fmt.Errorf("failed to count keys for provider config %d: %w", pc.ID, err) + } + + if keyCount == 0 { + if err := tx.Model(&tables.TableVirtualKeyProviderConfig{}). + Where("id = ?", pc.ID). + Update("allow_all_keys", true).Error; err != nil { + return fmt.Errorf("failed to backfill allow_all_keys for provider config %d: %w", pc.ID, err) + } + modifiedVKIDs[pc.VirtualKeyID] = struct{}{} + } + } + + // Recompute and persist config_hash for every VK that was modified. + for vkID := range modifiedVKIDs { + var vk tables.TableVirtualKey + if err := tx. + Preload("ProviderConfigs"). + Preload("ProviderConfigs.Keys"). + Preload("MCPConfigs"). + First(&vk, "id = ?", vkID).Error; err != nil { + return fmt.Errorf("failed to reload VK %s for hash recomputation: %w", vkID, err) + } + newHash, err := GenerateVirtualKeyHash(vk) + if err != nil { + return fmt.Errorf("failed to generate hash for VK %s: %w", vkID, err) + } + if err := tx.Model(&tables.TableVirtualKey{}). + Where("id = ?", vkID). + Update("config_hash", newHash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for VK %s: %w", vkID, err) + } + log.Printf("[Migration] Recomputed config_hash for VK '%s'", vk.Name) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if migratorInstance.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "allow_all_keys") { + if err := migratorInstance.DropColumn(&tables.TableVirtualKeyProviderConfig{}, "allow_all_keys"); err != nil { + return fmt.Errorf("failed to drop allow_all_keys column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running allow_all_keys migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPDisableAutoToolInjectColumn adds the mcp_disable_auto_tool_inject column to the client config table. +// When true, MCP tools are not automatically injected into requests; only explicit context filters apply. +func migrationAddMCPDisableAutoToolInjectColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_disable_auto_tool_inject_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if !migratorInstance.HasColumn(&tables.TableClientConfig{}, "mcp_disable_auto_tool_inject") { + if err := migratorInstance.AddColumn(&tables.TableClientConfig{}, "mcp_disable_auto_tool_inject"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migratorInstance := tx.Migrator() + if err := migratorInstance.DropColumn(&tables.TableClientConfig{}, "mcp_disable_auto_tool_inject"); err != nil { + return err + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running mcp disable auto tool inject migration: %s", err.Error()) + } + return nil +} + // migrationAddPricingRefactorColumns adds all new pricing columns introduced in the pricing module refactor func migrationAddPricingRefactorColumns(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -4739,50 +5123,245 @@ func migrationAddPromptRepoTables(ctx context.Context, db *gorm.DB) error { return nil } -// migrationAddPluginOrderColumns adds placement and exec_order columns to config_plugins table -func migrationAddPluginOrderColumns(ctx context.Context, db *gorm.DB) error { +// migrationBackfillAllowedModelsWildcard converts empty allowed_models on +// governance_virtual_key_provider_configs and empty models_json on keys to ["*"], +// preserving the previous "empty = allow all" semantics for existing records. +// After this migration the new convention applies: ["*"] = allow all, [] = deny all. +func migrationBackfillAllowedModelsWildcard(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ - ID: "add_plugin_order_columns", + ID: "backfill_allowed_models_wildcard", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - migrator := tx.Migrator() - if !migrator.HasColumn(&tables.TablePlugin{}, "placement") { - if err := migrator.AddColumn(&tables.TablePlugin{}, "Placement"); err != nil { - return fmt.Errorf("failed to add placement column: %w", err) - } + // --- Field 1: vk.provider_config.allowed_models --- + // Rows with '[]' previously meant "allow all models"; migrate to '["*"]'. + if err := tx.Model(&tables.TableVirtualKeyProviderConfig{}). + Where("allowed_models = ? OR allowed_models IS NULL", `[]`). + Update("allowed_models", `["*"]`).Error; err != nil { + return fmt.Errorf("failed to backfill provider_config allowed_models: %w", err) } - if !migrator.HasColumn(&tables.TablePlugin{}, "exec_order") { - if err := migrator.AddColumn(&tables.TablePlugin{}, "Order"); err != nil { - return fmt.Errorf("failed to add exec_order column: %w", err) + + // Recompute config_hash for all VKs that have provider configs + // (any of them may have had their allowed_models updated above). + var modifiedVKIDs []string + if err := tx.Model(&tables.TableVirtualKeyProviderConfig{}). + Distinct("virtual_key_id"). + Pluck("virtual_key_id", &modifiedVKIDs).Error; err != nil { + return fmt.Errorf("failed to query VK IDs for hash recomputation: %w", err) + } + + for _, vkID := range modifiedVKIDs { + var vk tables.TableVirtualKey + if err := tx. + Preload("ProviderConfigs"). + Preload("ProviderConfigs.Keys"). + Preload("MCPConfigs"). + First(&vk, "id = ?", vkID).Error; err != nil { + if err == gorm.ErrRecordNotFound { + // Orphaned provider config row β€” VK was deleted; skip. + continue + } + return fmt.Errorf("failed to reload VK %s for hash recomputation: %w", vkID, err) + } + newHash, err := GenerateVirtualKeyHash(vk) + if err != nil { + return fmt.Errorf("failed to generate hash for VK %s: %w", vkID, err) + } + if err := tx.Model(&tables.TableVirtualKey{}). + Where("id = ?", vkID). + Update("config_hash", newHash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for VK %s: %w", vkID, err) } + log.Printf("[Migration] Recomputed config_hash for VK '%s' after allowed_models backfill", vk.Name) } - return nil - }, - Rollback: func(tx *gorm.DB) error { - tx = tx.WithContext(ctx) - migrator := tx.Migrator() + // --- Field 2: provider.key.models (models_json column) --- + // Rows with '[]' or empty string previously meant "allow all models"; migrate to '["*"]'. + if err := tx.Model(&tables.TableKey{}). + Where("models_json = ? OR models_json = ? OR models_json IS NULL", `[]`, ``). + Update("models_json", `["*"]`).Error; err != nil { + return fmt.Errorf("failed to backfill key models_json: %w", err) + } - if migrator.HasColumn(&tables.TablePlugin{}, "placement") { - if err := migrator.DropColumn(&tables.TablePlugin{}, "placement"); err != nil { - return fmt.Errorf("failed to drop placement column: %w", err) - } + // Recompute config_hash for all keys since models_json is part of the hash input. + var keys []tables.TableKey + if err := tx.Find(&keys).Error; err != nil { + return fmt.Errorf("failed to fetch keys for hash recomputation: %w", err) } - if migrator.HasColumn(&tables.TablePlugin{}, "exec_order") { - if err := migrator.DropColumn(&tables.TablePlugin{}, "exec_order"); err != nil { - return fmt.Errorf("failed to drop exec_order column: %w", err) + for _, key := range keys { + schemaKey := schemas.Key{ + Name: key.Name, + Value: key.Value, + Models: key.Models, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + OllamaKeyConfig: key.OllamaKeyConfig, + SGLKeyConfig: key.SGLKeyConfig, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + } + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for key %s: %w", key.Name, err) + } + if err := tx.Model(&key).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for key %s: %w", key.Name, err) } } + return nil }, - }}) - if err := m.Migrate(); err != nil { + Rollback: func(tx *gorm.DB) error { + // Rollback is intentionally a no-op: reverting ["*"] back to [] would + // re-introduce the ambiguous "empty = allow all" semantics on downgrade. + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running backfill_allowed_models_wildcard migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPClientAllowedExtraHeadersJSONColumn adds the allowed_extra_headers_json column to the mcp_client table +func migrationAddMCPClientAllowedExtraHeadersJSONColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_client_allowed_extra_headers_json_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json") { + if err := migrator.DropColumn(&tables.TableMCPClient{}, "allowed_extra_headers_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running add_mcp_client_allowed_extra_headers_json_column migration: %s", err.Error()) + } + return nil +} + +// migrationAddPluginOrderColumns adds placement and exec_order columns to config_plugins table +func migrationAddPluginOrderColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_plugin_order_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if !migrator.HasColumn(&tables.TablePlugin{}, "placement") { + if err := migrator.AddColumn(&tables.TablePlugin{}, "Placement"); err != nil { + return fmt.Errorf("failed to add placement column: %w", err) + } + } + if !migrator.HasColumn(&tables.TablePlugin{}, "exec_order") { + if err := migrator.AddColumn(&tables.TablePlugin{}, "Order"); err != nil { + return fmt.Errorf("failed to add exec_order column: %w", err) + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if migrator.HasColumn(&tables.TablePlugin{}, "placement") { + if err := migrator.DropColumn(&tables.TablePlugin{}, "placement"); err != nil { + return fmt.Errorf("failed to drop placement column: %w", err) + } + } + if migrator.HasColumn(&tables.TablePlugin{}, "exec_order") { + if err := migrator.DropColumn(&tables.TablePlugin{}, "exec_order"); err != nil { + return fmt.Errorf("failed to drop exec_order column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { return fmt.Errorf("error while running add_plugin_order_columns migration: %s", err.Error()) } return nil } +// migrationMakeBasePricingColumnsNullable drops the NOT NULL constraint on +// input_cost_per_token and output_cost_per_token in governance_model_pricing, +// allowing models that only have non-token pricing (image, audio, video) to be +// stored without a placeholder zero value. +func migrationMakeBasePricingColumnsNullable(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "make_base_pricing_columns_nullable", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + m := tx.Migrator() + if err := m.AlterColumn(&tables.TableModelPricing{}, "InputCostPerToken"); err != nil { + return fmt.Errorf("failed to alter input_cost_per_token: %w", err) + } + if err := m.AlterColumn(&tables.TableModelPricing{}, "OutputCostPerToken"); err != nil { + return fmt.Errorf("failed to alter output_cost_per_token: %w", err) + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running make_base_pricing_columns_nullable migration: %s", err.Error()) + } + return nil +} + +// migrationAddAllowOnAllVirtualKeysColumn adds the allow_on_all_virtual_keys column to the mcp_client table +func migrationAddAllowOnAllVirtualKeysColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_allow_on_all_virtual_keys_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "allow_on_all_virtual_keys") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "allow_on_all_virtual_keys"); err != nil { + return fmt.Errorf("failed to add allow_on_all_virtual_keys column: %w", err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableMCPClient{}, "allow_on_all_virtual_keys") { + if err := migrator.DropColumn(&tables.TableMCPClient{}, "allow_on_all_virtual_keys"); err != nil { + return fmt.Errorf("failed to drop allow_on_all_virtual_keys column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running add_allow_on_all_virtual_keys_column migration: %s", err.Error()) + } + return nil +} + // migrationAddOpenAIConfigJSONColumn adds the open_ai_config_json column to the provider table func migrationAddOpenAIConfigJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -4849,16 +5428,36 @@ func migrationAddKeyBlacklistedModelsJSONColumn(ctx context.Context, db *gorm.DB return nil } -// migrationAddBudgetCalendarAlignedColumn adds the calendar_aligned column to the governance_budgets table. -func migrationAddBudgetCalendarAlignedColumn(ctx context.Context, db *gorm.DB) error { +// migrationAddChainRuleColumnToRoutingRules adds chain_rule to routing_rules. +// When true, the routing engine re-evaluates the full rule set after this rule matches, +// using the resolved provider/model as the new context input. +func migrationAddChainRuleColumnToRoutingRules(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ - ID: "add_budget_calendar_aligned_column", + ID: "add_chain_rule_column_to_routing_rules", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) mg := tx.Migrator() - if !mg.HasColumn(&tables.TableBudget{}, "calendar_aligned") { - if err := mg.AddColumn(&tables.TableBudget{}, "calendar_aligned"); err != nil { - return fmt.Errorf("failed to add calendar_aligned column: %w", err) + if !mg.HasColumn(&tables.TableRoutingRule{}, "chain_rule") { + if err := mg.AddColumn(&tables.TableRoutingRule{}, "chain_rule"); err != nil { + return fmt.Errorf("failed to add chain_rule column: %w", err) + } + } + + // Backfill config_hash for all existing routing rules. + // GenerateRoutingRuleHash now includes chain_rule, so existing hashes + // (computed without it) are stale and must be recomputed to avoid + // every rule appearing as changed after this upgrade. + var rules []tables.TableRoutingRule + if err := tx.Preload("Targets").Find(&rules).Error; err != nil { + return fmt.Errorf("failed to load routing rules for config_hash backfill: %w", err) + } + for _, rule := range rules { + hash, err := GenerateRoutingRuleHash(rule) + if err != nil { + return fmt.Errorf("failed to generate config_hash for routing rule %s: %w", rule.ID, err) + } + if err := tx.Model(&tables.TableRoutingRule{}).Where("id = ?", rule.ID).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for routing rule %s: %w", rule.ID, err) } } return nil @@ -4866,20 +5465,193 @@ func migrationAddBudgetCalendarAlignedColumn(ctx context.Context, db *gorm.DB) e Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) mg := tx.Migrator() - if mg.HasColumn(&tables.TableBudget{}, "calendar_aligned") { - if err := mg.DropColumn(&tables.TableBudget{}, "calendar_aligned"); err != nil { - return fmt.Errorf("failed to drop calendar_aligned column: %w", err) + if mg.HasColumn(&tables.TableRoutingRule{}, "chain_rule") { + if err := mg.DropColumn(&tables.TableRoutingRule{}, "chain_rule"); err != nil { + return fmt.Errorf("failed to drop chain_rule column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_chain_rule_column_to_routing_rules migration: %s", err.Error()) + } + return nil +} + +// migrationAddReplicateKeyConfigColumn adds the replicate_use_deployments_endpoint column to the key table +func migrationAddReplicateKeyConfigColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_replicate_key_config_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint") { + if err := mg.AddColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint"); err != nil { + return err + } + // Backfill: Replicate keys that had deployments configured (now in aliases_json after + // migrationDropDeploymentColumnsAndAddAliases) were using the deployments endpoint. + trueVal := true + if err := tx.Model(&tables.TableKey{}). + Where("provider = ? AND aliases_json IS NOT NULL AND aliases_json != ? AND aliases_json != ?", + string(schemas.Replicate), "", "{}", + ). + Update("ReplicateUseDeploymentsEndpoint", &trueVal).Error; err != nil { + return err + } + + // Recompute config_hash for Replicate keys that were updated above, + // since replicate_use_deployments_endpoint is part of the hash input. + var affectedKeys []tables.TableKey + if err := tx.Where( + "provider = ? AND replicate_use_deployments_endpoint IS NOT NULL", + string(schemas.Replicate), + ).Find(&affectedKeys).Error; err != nil { + return fmt.Errorf("failed to fetch replicate keys for hash recomputation: %w", err) + } + for _, key := range affectedKeys { + schemaKey := schemas.Key{ + Name: key.Name, + Value: key.Value, + Models: key.Models, + BlacklistedModels: key.BlacklistedModels, + Weight: getWeight(key.Weight), + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + } + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for key %s: %w", key.Name, err) + } + if err := tx.Model(&key).Update("config_hash", hash).Error; err != nil { + return fmt.Errorf("failed to update config_hash for key %s: %w", key.Name, err) + } + log.Printf("[Migration] Recomputed config_hash for replicate key '%s' after replicate config backfill", key.Name) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint") { + if err := mg.DropColumn(&tables.TableKey{}, "replicate_use_deployments_endpoint"); err != nil { + return err } } return nil }, }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_replicate_key_config_column migration: %s", err.Error()) + } + return nil +} + +// migrationAddBudgetCalendarAlignedColumn was originally for adding calendar_aligned to governance_budgets. +// Calendar alignment is now a VK-level field (governance_virtual_keys.calendar_aligned) added in migrationAddMultiBudgetTables. +// This migration is kept as a no-op so the migrator doesn't try to re-run it. +func migrationAddBudgetCalendarAlignedColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_budget_calendar_aligned_column", + Migrate: func(tx *gorm.DB) error { return nil }, + Rollback: func(tx *gorm.DB) error { return nil }, + }}) if err := m.Migrate(); err != nil { return fmt.Errorf("error running add_budget_calendar_aligned_column migration: %s", err.Error()) } return nil } +// migrationAddRoutingChainMaxDepthColumn adds routing_chain_max_depth to the client config table. +// Defaults to 10, which is the built-in default for routing rule chain evaluation depth. +func migrationAddRoutingChainMaxDepthColumn(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_routing_chain_max_depth_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasColumn(&tables.TableClientConfig{}, "routing_chain_max_depth") { + if err := mg.AddColumn(&tables.TableClientConfig{}, "routing_chain_max_depth"); err != nil { + return fmt.Errorf("failed to add routing_chain_max_depth column: %w", err) + } + // Recompute config_hash for all existing client configs that have one. + // RoutingChainMaxDepth is now included in the hash (when > 0), so without + // this recompute the stored hash would mismatch on every startup after upgrade. + var clientConfigs []tables.TableClientConfig + if err := tx.Find(&clientConfigs).Error; err != nil { + return fmt.Errorf("failed to fetch client configs for hash recompute: %w", err) + } + for _, cc := range clientConfigs { + if cc.ConfigHash == "" { + continue // no stored hash to invalidate + } + depth := cc.RoutingChainMaxDepth + if depth == 0 { + // Should never happen, but just in case. + depth = 10 // DefaultRoutingChainMaxDepth + } + clientConfig := ClientConfig{ + DropExcessRequests: cc.DropExcessRequests, + InitialPoolSize: cc.InitialPoolSize, + PrometheusLabels: cc.PrometheusLabels, + EnableLogging: cc.EnableLogging, + DisableContentLogging: cc.DisableContentLogging, + DisableDBPingsInHealth: cc.DisableDBPingsInHealth, + LogRetentionDays: cc.LogRetentionDays, + EnforceAuthOnInference: cc.EnforceAuthOnInference, + AllowDirectKeys: cc.AllowDirectKeys, + AllowedOrigins: cc.AllowedOrigins, + AllowedHeaders: cc.AllowedHeaders, + MaxRequestBodySizeMB: cc.MaxRequestBodySizeMB, + EnableLiteLLMFallbacks: cc.EnableLiteLLMFallbacks, + HideDeletedVirtualKeysInFilters: cc.HideDeletedVirtualKeysInFilters, + MCPAgentDepth: cc.MCPAgentDepth, + MCPToolExecutionTimeout: cc.MCPToolExecutionTimeout, + MCPCodeModeBindingLevel: cc.MCPCodeModeBindingLevel, + MCPToolSyncInterval: cc.MCPToolSyncInterval, + MCPDisableAutoToolInject: cc.MCPDisableAutoToolInject, + AsyncJobResultTTL: cc.AsyncJobResultTTL, + LoggingHeaders: cc.LoggingHeaders, + RequiredHeaders: cc.RequiredHeaders, + HeaderFilterConfig: cc.HeaderFilterConfig, + RoutingChainMaxDepth: depth, + } + newHash, err := clientConfig.GenerateClientConfigHash() + if err != nil { + return fmt.Errorf("failed to generate hash for client config %d: %w", cc.ID, err) + } + if err := tx.Model(&cc).Update("config_hash", newHash).Error; err != nil { + return fmt.Errorf("failed to update hash for client config %d: %w", cc.ID, err) + } + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableClientConfig{}, "routing_chain_max_depth") { + if err := mg.DropColumn(&tables.TableClientConfig{}, "routing_chain_max_depth"); err != nil { + return fmt.Errorf("failed to drop routing_chain_max_depth column: %w", err) + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_routing_chain_max_depth_column migration: %s", err.Error()) + } + return nil +} + // migrationAddModelCapabilityColumns adds model capability metadata columns to governance_model_pricing. func migrationAddModelCapabilityColumns(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -4927,6 +5699,357 @@ func migrationAddModelCapabilityColumns(ctx context.Context, db *gorm.DB) error return nil } +// migrationAddOllamaSGLConfigColumns adds ollama_url and sgl_url columns to the key table +func migrationAddOllamaSGLConfigColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_ollama_sgl_config_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableKey{}, "ollama_url") { + if err := migrator.AddColumn(&tables.TableKey{}, "ollama_url"); err != nil { + return err + } + } + if !migrator.HasColumn(&tables.TableKey{}, "sgl_url") { + if err := migrator.AddColumn(&tables.TableKey{}, "sgl_url"); err != nil { + return err + } + } + + // Backfill: for each ollama/sgl provider with a base_url, create a key + // with that URL and clear base_url from network_config. + var providers []tables.TableProvider + if err := tx.Where("name IN ?", []string{"ollama", "sgl"}).Find(&providers).Error; err != nil { + return fmt.Errorf("failed to fetch ollama/sgl providers for URL backfill: %w", err) + } + for _, p := range providers { + if p.NetworkConfigJSON == "" { + continue + } + var nc schemas.NetworkConfig + if err := json.Unmarshal([]byte(p.NetworkConfigJSON), &nc); err != nil { + log.Printf("[Migration] Failed to parse network_config for provider %s (id=%d), skipping: %v", p.Name, p.ID, err) + continue + } + if nc.BaseURL == "" { + continue + } + + // Create a new key with the provider's base_url + urlEnvVar := schemas.EnvVar{Val: nc.BaseURL} + enabled := true + weight := 1.0 + newKey := tables.TableKey{ + Provider: p.Name, + ProviderID: p.ID, + KeyID: uuid.NewString(), + Weight: &weight, + Enabled: &enabled, + Models: schemas.WhiteList{"*"}, + } + if strings.ToLower(p.Name) == "ollama" { + newKey.Name = "Default Ollama Key" + newKey.OllamaKeyConfig = &schemas.OllamaKeyConfig{URL: urlEnvVar} + } + if strings.ToLower(p.Name) == "sgl" { + newKey.Name = "Default SGL Key" + newKey.SGLKeyConfig = &schemas.SGLKeyConfig{URL: urlEnvVar} + } + + schemaKey := schemaKeyFromTableKey(newKey) + hash, err := GenerateKeyHash(schemaKey) + if err != nil { + return fmt.Errorf("failed to generate hash for new key on provider %s: %w", p.Name, err) + } + newKey.ConfigHash = hash + if err := tx.Create(&newKey).Error; err != nil { + return fmt.Errorf("failed to create key for provider %s: %w", p.Name, err) + } + + log.Printf("[Migration] Created key '%s' for provider '%s' from network_config.base_url", newKey.Name, p.Name) + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableKey{}, "ollama_url") { + if err := migrator.DropColumn(&tables.TableKey{}, "ollama_url"); err != nil { + return err + } + } + if migrator.HasColumn(&tables.TableKey{}, "sgl_url") { + if err := migrator.DropColumn(&tables.TableKey{}, "sgl_url"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running ollama sgl key config columns migration: %s", err.Error()) + } + return nil +} + +// migrationAddMultiBudgetTables creates junction tables for multi-budget support and backfills existing data. +func migrationAddMultiBudgetTables(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_multi_budget_tables", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + + // Add calendar_aligned to governance_virtual_keys (VK-level setting) + if !mg.HasColumn(&tables.TableVirtualKey{}, "calendar_aligned") { + if err := mg.AddColumn(&tables.TableVirtualKey{}, "CalendarAligned"); err != nil { + return fmt.Errorf("failed to add calendar_aligned column to governance_virtual_keys: %w", err) + } + } + + // Add FK columns on governance_budgets for multi-budget ownership + if !mg.HasColumn(&tables.TableBudget{}, "virtual_key_id") { + if err := mg.AddColumn(&tables.TableBudget{}, "VirtualKeyID"); err != nil { + return fmt.Errorf("failed to add virtual_key_id column to governance_budgets: %w", err) + } + } + if !mg.HasColumn(&tables.TableBudget{}, "provider_config_id") { + if err := mg.AddColumn(&tables.TableBudget{}, "ProviderConfigID"); err != nil { + return fmt.Errorf("failed to add provider_config_id column to governance_budgets: %w", err) + } + } + + // Create indexes on the new FK columns (AddColumn doesn't create indexes from struct tags) + if !mg.HasIndex(&tables.TableBudget{}, "idx_governance_budgets_virtual_key_id") { + if err := mg.CreateIndex(&tables.TableBudget{}, "VirtualKeyID"); err != nil { + return fmt.Errorf("failed to create index on governance_budgets.virtual_key_id: %w", err) + } + } + if !mg.HasIndex(&tables.TableBudget{}, "idx_governance_budgets_provider_config_id") { + if err := mg.CreateIndex(&tables.TableBudget{}, "ProviderConfigID"); err != nil { + return fmt.Errorf("failed to create index on governance_budgets.provider_config_id: %w", err) + } + } + + // Create FK constraints with CASCADE delete (defined on parent structs) + if !mg.HasConstraint(&tables.TableVirtualKey{}, "Budgets") { + if err := mg.CreateConstraint(&tables.TableVirtualKey{}, "Budgets"); err != nil { + return fmt.Errorf("failed to create FK constraint for VirtualKey -> Budgets: %w", err) + } + } + if !mg.HasConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budgets") { + if err := mg.CreateConstraint(&tables.TableVirtualKeyProviderConfig{}, "Budgets"); err != nil { + return fmt.Errorf("failed to create FK constraint for ProviderConfig -> Budgets: %w", err) + } + } + + // Backfill: set virtual_key_id from legacy VK budget_id (if column still exists) + if mg.HasColumn(&tables.TableVirtualKey{}, "budget_id") { + if err := tx.Exec(` + UPDATE governance_budgets SET virtual_key_id = ( + SELECT id FROM governance_virtual_keys + WHERE governance_virtual_keys.budget_id = governance_budgets.id + ) WHERE virtual_key_id IS NULL AND EXISTS ( + SELECT 1 FROM governance_virtual_keys + WHERE governance_virtual_keys.budget_id = governance_budgets.id + ) + `).Error; err != nil { + return fmt.Errorf("failed to backfill VK budget virtual_key_id: %w", err) + } + } + + // Backfill: set provider_config_id from legacy PC budget_id (if column still exists) + if mg.HasColumn(&tables.TableVirtualKeyProviderConfig{}, "budget_id") { + if err := tx.Exec(` + UPDATE governance_budgets SET provider_config_id = ( + SELECT id FROM governance_virtual_key_provider_configs + WHERE governance_virtual_key_provider_configs.budget_id = governance_budgets.id + ) WHERE provider_config_id IS NULL AND EXISTS ( + SELECT 1 FROM governance_virtual_key_provider_configs + WHERE governance_virtual_key_provider_configs.budget_id = governance_budgets.id + ) + `).Error; err != nil { + return fmt.Errorf("failed to backfill PC budget provider_config_id: %w", err) + } + } + + // Backfill: copy calendar_aligned from legacy budget column to VK-level field + // (governance_budgets.calendar_aligned was added by add_budget_calendar_aligned_column on main) + if mg.HasColumn(&tables.TableBudget{}, "calendar_aligned") { + if err := tx.Exec(` + UPDATE governance_virtual_keys SET calendar_aligned = true + WHERE id IN ( + SELECT DISTINCT virtual_key_id FROM governance_budgets + WHERE calendar_aligned = true AND virtual_key_id IS NOT NULL + ) AND calendar_aligned = false + `).Error; err != nil { + return fmt.Errorf("failed to backfill calendar_aligned from budgets to virtual keys: %w", err) + } + // Drop the legacy calendar_aligned column from governance_budgets + _ = tx.Exec("ALTER TABLE governance_budgets DROP COLUMN IF EXISTS calendar_aligned") + } + + // Drop legacy budget_id columns from VK and ProviderConfig (raw SQL to avoid GORM FK lookup issues) + _ = tx.Exec("ALTER TABLE governance_virtual_keys DROP COLUMN IF EXISTS budget_id") + _ = tx.Exec("ALTER TABLE governance_virtual_key_provider_configs DROP COLUMN IF EXISTS budget_id") + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if mg.HasColumn(&tables.TableBudget{}, "virtual_key_id") { + if err := mg.DropColumn(&tables.TableBudget{}, "virtual_key_id"); err != nil { + return err + } + } + if mg.HasColumn(&tables.TableBudget{}, "provider_config_id") { + if err := mg.DropColumn(&tables.TableBudget{}, "provider_config_id"); err != nil { + return err + + } + } + return nil + }, + }}) + // SQLite workaround: GORM's CreateConstraint rebuilds the table via DROP+RENAME + // inside a transaction. The DROP fails when other tables have FKs pointing at the + // target table and foreign_keys is ON. PRAGMA foreign_keys cannot be changed inside + // a transaction, so we disable it before the migrator opens its transaction. + // This only affects SQLite β€” Postgres supports ALTER TABLE ADD CONSTRAINT natively. + if db.Dialector.Name() == "sqlite" { + // PRAGMA foreign_keys is per-connection in SQLite. Pin the pool to a single + // connection so the PRAGMA and the migration transaction share the same one. + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("failed to get underlying sql.DB: %w", err) + } + sqlDB.SetMaxOpenConns(1) + defer sqlDB.SetMaxOpenConns(0) // restore default + + if err := db.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + return fmt.Errorf("failed to disable SQLite foreign keys: %w", err) + } + defer func() { + if err := db.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + log.Fatalf("[Migration] FATAL: failed to re-enable SQLite foreign keys: %v", err) + } + }() + } + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_multi_budget_tables migration: %s", err.Error()) + } + return nil +} + +// migrationAddPerUserOAuthTables adds the oauth_user_sessions and oauth_user_tokens tables +func migrationAddPerUserOAuthTables(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_per_user_oauth_tables", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + if !mg.HasTable(&tables.TablePerUserOAuthClient{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthClient{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_clients table: %w", err) + } + } + if !mg.HasTable(&tables.TablePerUserOAuthSession{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthSession{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_sessions table: %w", err) + } + } + if !mg.HasTable(&tables.TablePerUserOAuthCode{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthCode{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_codes table: %w", err) + } + } + if !mg.HasTable(&tables.TableOauthUserToken{}) { + if err := mg.CreateTable(&tables.TableOauthUserToken{}); err != nil { + return fmt.Errorf("failed to create oauth_user_tokens table: %w", err) + } + } + if !mg.HasTable(&tables.TableOauthUserSession{}) { + if err := mg.CreateTable(&tables.TableOauthUserSession{}); err != nil { + return fmt.Errorf("failed to create oauth_user_sessions table: %w", err) + } + } + if !mg.HasTable(&tables.TablePerUserOAuthPendingFlow{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthPendingFlow{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_pending_flows table: %w", err) + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mg := tx.Migrator() + for _, table := range []any{ + &tables.TablePerUserOAuthPendingFlow{}, + &tables.TablePerUserOAuthCode{}, + &tables.TablePerUserOAuthSession{}, + &tables.TablePerUserOAuthClient{}, + &tables.TableOauthUserToken{}, + &tables.TableOauthUserSession{}, + } { + if mg.HasTable(table) { + if err := mg.DropTable(table); err != nil { + return err + } + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_per_user_oauth_tables migration: %s", err.Error()) + } + return nil +} + +// migrationAddMCPClientDiscoveredToolsColumns adds discovered_tools_json and tool_name_mapping_json columns to the mcp_client table +func migrationAddMCPClientDiscoveredToolsColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_mcp_client_discovered_tools_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&tables.TableMCPClient{}, "discovered_tools_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "discovered_tools_json"); err != nil { + return err + } + } + if !migrator.HasColumn(&tables.TableMCPClient{}, "tool_name_mapping_json") { + if err := migrator.AddColumn(&tables.TableMCPClient{}, "tool_name_mapping_json"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TableMCPClient{}, "discovered_tools_json") { + if err := migrator.DropColumn(&tables.TableMCPClient{}, "discovered_tools_json"); err != nil { + return err + } + } + if migrator.HasColumn(&tables.TableMCPClient{}, "tool_name_mapping_json") { + if err := migrator.DropColumn(&tables.TableMCPClient{}, "tool_name_mapping_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error running add_mcp_client_discovered_tools_columns migration: %s", err.Error()) + } + return nil +} + // migrationAddWhitelistedRoutesJSONColumn adds the whitelisted_routes_json column to the config_client table func migrationAddWhitelistedRoutesJSONColumn(ctx context.Context, db *gorm.DB) error { m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ @@ -4940,24 +6063,21 @@ func migrationAddWhitelistedRoutesJSONColumn(ctx context.Context, db *gorm.DB) e return fmt.Errorf("failed to add whitelisted_routes_json column: %w", err) } } - return nil }, Rollback: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) migrator := tx.Migrator() - if migrator.HasColumn(&tables.TableClientConfig{}, "whitelisted_routes_json") { if err := migrator.DropColumn(&tables.TableClientConfig{}, "whitelisted_routes_json"); err != nil { return fmt.Errorf("failed to drop whitelisted_routes_json column: %w", err) } } - return nil }, }}) if err := m.Migrate(); err != nil { - return fmt.Errorf("error running whitelisted_routes_json migration: %s", err.Error()) + return fmt.Errorf("error running add_whitelisted_routes_json_column migration: %s", err.Error()) } return nil } diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go index a57f7262cb..df455333d6 100644 --- a/framework/configstore/migrations_test.go +++ b/framework/configstore/migrations_test.go @@ -582,13 +582,13 @@ func setupProviderTestDBWithoutStoreRawColumn(t *testing.T) *gorm.DB { `).Error require.NoError(t, err, "Failed to create config_providers table") - // Create the gomigrate table for the migrator + // Create the migrations table for the migrator (matches migrator.DefaultOptions.TableName) err = db.Exec(` - CREATE TABLE IF NOT EXISTS gomigrate ( + CREATE TABLE IF NOT EXISTS migrations ( id VARCHAR(255) PRIMARY KEY ) `).Error - require.NoError(t, err, "Failed to create gomigrate table") + require.NoError(t, err, "Failed to create migrations table") return db } @@ -614,9 +614,18 @@ func trySetupPostgresDBWithoutStoreRawColumn(t *testing.T, testSuffix string) *g return nil } - // Drop the table if it exists to start fresh (for this specific test) - db.Exec("DROP TABLE IF EXISTS gomigrate") - db.Exec("DROP TABLE IF EXISTS config_providers") + // Drop config_providers to start fresh (for this specific test). + // Use CASCADE to drop dependent objects (composite types, sequences, etc.). + db.Exec("DROP TABLE IF EXISTS config_providers CASCADE") + + // Clear migration tracking without dropping the table β€” other test packages + // (e.g. logstore) may share this Postgres instance and use the same table + // concurrently. CREATE IF NOT EXISTS is safe even if the table already + // exists from a previous test or a concurrent package. + db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( + id VARCHAR(255) PRIMARY KEY + )`) + db.Exec("DELETE FROM migrations") // Create the config_providers table manually without store_raw_request_response column // This simulates the pre-migration state (PostgreSQL syntax) @@ -645,20 +654,11 @@ func trySetupPostgresDBWithoutStoreRawColumn(t *testing.T, testSuffix string) *g return nil } - // Create the gomigrate table for the migrator - err = db.Exec(` - CREATE TABLE IF NOT EXISTS gomigrate ( - id VARCHAR(255) PRIMARY KEY - ) - `).Error - if err != nil { - return nil - } - - // Clean up tables after the test + // Clean up after the test β€” drop config_providers but leave migrations + // intact for concurrent test packages. t.Cleanup(func() { - db.Exec("DROP TABLE IF EXISTS gomigrate") - db.Exec("DROP TABLE IF EXISTS config_providers") + db.Exec("DELETE FROM migrations") + db.Exec("DROP TABLE IF EXISTS config_providers CASCADE") }) return db diff --git a/framework/configstore/prompts.go b/framework/configstore/prompts.go index 18b3638bb8..e760351b95 100644 --- a/framework/configstore/prompts.go +++ b/framework/configstore/prompts.go @@ -30,9 +30,6 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, if err := s.db.WithContext(ctx). Order("created_at DESC"). Find(&folders).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TableFolder{}, nil - } return nil, err } @@ -147,9 +144,6 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta } if err := query.Find(&prompts).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePrompt{}, nil - } return nil, err } @@ -261,6 +255,18 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { // Prompt Repository - Versions // ============================================================================ +// GetAllPromptVersions returns every version across all prompts in a single query. +func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { + var versions []tables.TablePromptVersion + if err := s.db.WithContext(ctx). + Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). + Order("prompt_id ASC, version_number DESC"). + Find(&versions).Error; err != nil { + return nil, err + } + return versions, nil +} + // GetPromptVersions gets all versions for a prompt func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion @@ -269,9 +275,6 @@ func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) Where("prompt_id = ?", promptID). Order("version_number DESC"). Find(&versions).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePromptVersion{}, nil - } return nil, err } return versions, nil @@ -416,9 +419,6 @@ func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) Where("prompt_id = ?", promptID). Order("created_at DESC"). Find(&sessions).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePromptSession{}, nil - } return nil, err } return sessions, nil diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 343e502cd2..624ab27300 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -35,6 +35,91 @@ func getWeight(w *float64) float64 { return *w } +// schemaKeyFromTableKey converts a database key to a schema key. +func schemaKeyFromTableKey(dbKey tables.TableKey) schemas.Key { + return schemas.Key{ + ID: dbKey.KeyID, + Name: dbKey.Name, + Value: dbKey.Value, + Models: dbKey.Models, + BlacklistedModels: dbKey.BlacklistedModels, + Weight: getWeight(dbKey.Weight), + Enabled: dbKey.Enabled, + UseForBatchAPI: dbKey.UseForBatchAPI, + AzureKeyConfig: dbKey.AzureKeyConfig, + VertexKeyConfig: dbKey.VertexKeyConfig, + BedrockKeyConfig: dbKey.BedrockKeyConfig, + Aliases: dbKey.Aliases, + VLLMKeyConfig: dbKey.VLLMKeyConfig, + ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + OllamaKeyConfig: dbKey.OllamaKeyConfig, + SGLKeyConfig: dbKey.SGLKeyConfig, + ConfigHash: dbKey.ConfigHash, + Status: schemas.KeyStatusType(dbKey.Status), + Description: dbKey.Description, + } +} + +// tableKeyFromSchemaKey converts a schema key to a database key. +func tableKeyFromSchemaKey(provider tables.TableProvider, key schemas.Key) (tables.TableKey, error) { + dbKey := tables.TableKey{ + Provider: provider.Name, + ProviderID: provider.ID, + KeyID: key.ID, + Name: key.Name, + Value: key.Value, + Models: key.Models, + BlacklistedModels: key.BlacklistedModels, + Weight: &key.Weight, + Enabled: key.Enabled, + UseForBatchAPI: key.UseForBatchAPI, + AzureKeyConfig: key.AzureKeyConfig, + VertexKeyConfig: key.VertexKeyConfig, + BedrockKeyConfig: key.BedrockKeyConfig, + Aliases: key.Aliases, + VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + OllamaKeyConfig: key.OllamaKeyConfig, + SGLKeyConfig: key.SGLKeyConfig, + ConfigHash: key.ConfigHash, + Status: string(key.Status), + Description: key.Description, + } + + if key.AzureKeyConfig != nil { + dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint + dbKey.AzureAPIVersion = key.AzureKeyConfig.APIVersion + } + + if key.VertexKeyConfig != nil { + dbKey.VertexProjectID = &key.VertexKeyConfig.ProjectID + dbKey.VertexProjectNumber = &key.VertexKeyConfig.ProjectNumber + dbKey.VertexRegion = &key.VertexKeyConfig.Region + dbKey.VertexAuthCredentials = &key.VertexKeyConfig.AuthCredentials + } + + if key.BedrockKeyConfig != nil { + dbKey.BedrockAccessKey = &key.BedrockKeyConfig.AccessKey + dbKey.BedrockSecretKey = &key.BedrockKeyConfig.SecretKey + dbKey.BedrockSessionToken = key.BedrockKeyConfig.SessionToken + dbKey.BedrockRegion = key.BedrockKeyConfig.Region + dbKey.BedrockARN = key.BedrockKeyConfig.ARN + dbKey.BedrockRoleARN = key.BedrockKeyConfig.RoleARN + dbKey.BedrockExternalID = key.BedrockKeyConfig.ExternalID + dbKey.BedrockRoleSessionName = key.BedrockKeyConfig.RoleSessionName + if key.BedrockKeyConfig.BatchS3Config != nil { + data, err := sonic.Marshal(key.BedrockKeyConfig.BatchS3Config) + if err != nil { + return tables.TableKey{}, err + } + s := string(data) + dbKey.BedrockBatchS3ConfigJSON = &s + } + } + + return dbKey, nil +} + // UpdateClientConfig updates the client configuration in the database. func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientConfig) error { dbConfig := tables.TableClientConfig{ @@ -57,11 +142,13 @@ func (s *RDBConfigStore) UpdateClientConfig(ctx context.Context, config *ClientC MCPToolExecutionTimeout: config.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: config.MCPCodeModeBindingLevel, MCPToolSyncInterval: config.MCPToolSyncInterval, + MCPDisableAutoToolInject: config.MCPDisableAutoToolInject, AsyncJobResultTTL: config.AsyncJobResultTTL, RequiredHeaders: config.RequiredHeaders, LoggingHeaders: config.LoggingHeaders, WhitelistedRoutes: config.WhitelistedRoutes, HideDeletedVirtualKeysInFilters: config.HideDeletedVirtualKeysInFilters, + RoutingChainMaxDepth: config.RoutingChainMaxDepth, HeaderFilterConfig: config.HeaderFilterConfig, ConfigHash: config.ConfigHash, } @@ -91,13 +178,10 @@ func (s *RDBConfigStore) parseGormError(err error) error { if err == nil { return nil } - if errors.Is(err, gorm.ErrRecordNotFound) { return ErrNotFound } - errMsg := err.Error() - // Check for unique constraint violations // SQLite format: "UNIQUE constraint failed: table_name.column_name" // PostgreSQL format: "ERROR: duplicate key value violates unique constraint" @@ -224,11 +308,13 @@ func (s *RDBConfigStore) GetClientConfig(ctx context.Context) (*ClientConfig, er MCPToolExecutionTimeout: dbConfig.MCPToolExecutionTimeout, MCPCodeModeBindingLevel: dbConfig.MCPCodeModeBindingLevel, MCPToolSyncInterval: dbConfig.MCPToolSyncInterval, + MCPDisableAutoToolInject: dbConfig.MCPDisableAutoToolInject, AsyncJobResultTTL: dbConfig.AsyncJobResultTTL, RequiredHeaders: dbConfig.RequiredHeaders, LoggingHeaders: dbConfig.LoggingHeaders, WhitelistedRoutes: dbConfig.WhitelistedRoutes, HideDeletedVirtualKeysInFilters: dbConfig.HideDeletedVirtualKeysInFilters, + RoutingChainMaxDepth: dbConfig.RoutingChainMaxDepth, HeaderFilterConfig: dbConfig.HeaderFilterConfig, ConfigHash: dbConfig.ConfigHash, }, nil @@ -253,7 +339,6 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma StoreRawRequestResponse: providerConfig.StoreRawRequestResponse, CustomProviderConfig: providerConfig.CustomProviderConfig, OpenAIConfig: providerConfig.OpenAIConfig, - PricingOverrides: providerConfig.PricingOverrides, ConfigHash: providerConfig.ConfigHash, Status: providerConfig.Status, Description: providerConfig.Description, @@ -297,8 +382,11 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + OllamaKeyConfig: key.OllamaKeyConfig, + SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, @@ -357,6 +445,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status + dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } @@ -372,6 +461,7 @@ func (s *RDBConfigStore) UpdateProvidersConfig(ctx context.Context, providers ma dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status + dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } @@ -426,7 +516,6 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo dbProvider.StoreRawRequestResponse = configCopy.StoreRawRequestResponse dbProvider.CustomProviderConfig = configCopy.CustomProviderConfig dbProvider.OpenAIConfig = configCopy.OpenAIConfig - dbProvider.PricingOverrides = configCopy.PricingOverrides dbProvider.ConfigHash = configCopy.ConfigHash // Save the updated provider @@ -467,8 +556,11 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + OllamaKeyConfig: key.OllamaKeyConfig, + SGLKeyConfig: key.SGLKeyConfig, ConfigHash: keyHash, Status: string(key.Status), Description: key.Description, @@ -517,6 +609,7 @@ func (s *RDBConfigStore) UpdateProvider(ctx context.Context, provider schemas.Mo dbKey.Status = existingKey.Status // Preserve status (UI-managed) dbKey.Description = existingKey.Description // Preserve description (UI-managed) dbKey.EncryptionStatus = existingKey.EncryptionStatus // Preserve encryption status + dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { return s.parseGormError(err) } @@ -567,7 +660,6 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model StoreRawRequestResponse: configCopy.StoreRawRequestResponse, CustomProviderConfig: configCopy.CustomProviderConfig, OpenAIConfig: configCopy.OpenAIConfig, - PricingOverrides: configCopy.PricingOverrides, ConfigHash: configCopy.ConfigHash, } // Create the provider @@ -590,8 +682,11 @@ func (s *RDBConfigStore) AddProvider(ctx context.Context, provider schemas.Model AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, - ReplicateKeyConfig: key.ReplicateKeyConfig, + Aliases: key.Aliases, VLLMKeyConfig: key.VLLMKeyConfig, + ReplicateKeyConfig: key.ReplicateKeyConfig, + OllamaKeyConfig: key.OllamaKeyConfig, + SGLKeyConfig: key.SGLKeyConfig, ConfigHash: key.ConfigHash, Status: string(key.Status), Description: key.Description, @@ -700,24 +795,7 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo // Convert database keys to schemas.Key keys := make([]schemas.Key, len(dbProvider.Keys)) for i, dbKey := range dbProvider.Keys { - keys[i] = schemas.Key{ - ID: dbKey.KeyID, - Name: dbKey.Name, - Value: dbKey.Value, - Models: dbKey.Models, - BlacklistedModels: dbKey.BlacklistedModels, - Weight: getWeight(dbKey.Weight), - Enabled: dbKey.Enabled, - UseForBatchAPI: dbKey.UseForBatchAPI, - AzureKeyConfig: dbKey.AzureKeyConfig, - VertexKeyConfig: dbKey.VertexKeyConfig, - BedrockKeyConfig: dbKey.BedrockKeyConfig, - ReplicateKeyConfig: dbKey.ReplicateKeyConfig, - VLLMKeyConfig: dbKey.VLLMKeyConfig, - ConfigHash: dbKey.ConfigHash, - Status: schemas.KeyStatusType(dbKey.Status), - Description: dbKey.Description, - } + keys[i] = schemaKeyFromTableKey(dbKey) } providerConfig := ProviderConfig{ Keys: keys, @@ -729,7 +807,6 @@ func (s *RDBConfigStore) GetProvidersConfig(ctx context.Context) (map[schemas.Mo StoreRawRequestResponse: dbProvider.StoreRawRequestResponse, CustomProviderConfig: dbProvider.CustomProviderConfig, OpenAIConfig: dbProvider.OpenAIConfig, - PricingOverrides: dbProvider.PricingOverrides, ConfigHash: dbProvider.ConfigHash, Status: dbProvider.Status, Description: dbProvider.Description, @@ -751,24 +828,7 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas keys := make([]schemas.Key, len(dbProvider.Keys)) for i, dbKey := range dbProvider.Keys { - keys[i] = schemas.Key{ - ID: dbKey.KeyID, - Name: dbKey.Name, - Value: dbKey.Value, - Models: dbKey.Models, - BlacklistedModels: dbKey.BlacklistedModels, - Weight: getWeight(dbKey.Weight), - Enabled: dbKey.Enabled, - UseForBatchAPI: dbKey.UseForBatchAPI, - AzureKeyConfig: dbKey.AzureKeyConfig, - VertexKeyConfig: dbKey.VertexKeyConfig, - BedrockKeyConfig: dbKey.BedrockKeyConfig, - ReplicateKeyConfig: dbKey.ReplicateKeyConfig, - VLLMKeyConfig: dbKey.VLLMKeyConfig, - ConfigHash: dbKey.ConfigHash, - Status: schemas.KeyStatusType(dbKey.Status), - Description: dbKey.Description, - } + keys[i] = schemaKeyFromTableKey(dbKey) } return &ProviderConfig{ Keys: keys, @@ -780,13 +840,160 @@ func (s *RDBConfigStore) GetProviderConfig(ctx context.Context, provider schemas StoreRawRequestResponse: dbProvider.StoreRawRequestResponse, CustomProviderConfig: dbProvider.CustomProviderConfig, OpenAIConfig: dbProvider.OpenAIConfig, - PricingOverrides: dbProvider.PricingOverrides, ConfigHash: dbProvider.ConfigHash, Status: dbProvider.Status, Description: dbProvider.Description, }, nil } +// GetProviderKeys retrieves all keys for a provider ordered by creation time. +func (s *RDBConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + var dbKeys []tables.TableKey + result := s.db.WithContext(ctx). + Table("config_providers"). + Select("config_keys.*"). + Joins("LEFT JOIN config_keys ON config_keys.provider_id = config_providers.id"). + Where("config_providers.name = ?", string(provider)). + Order("config_keys.created_at ASC"). + Scan(&dbKeys) + if result.Error != nil { + return nil, result.Error + } + if result.RowsAffected == 0 { + return nil, ErrNotFound + } + if len(dbKeys) == 1 && dbKeys[0].ID == 0 && dbKeys[0].KeyID == "" { + return []schemas.Key{}, nil + } + + keys := make([]schemas.Key, 0, len(dbKeys)) + for _, dbKey := range dbKeys { + if dbKey.ID == 0 && dbKey.KeyID == "" { + continue + } + if err := dbKey.AfterFind(nil); err != nil { + return nil, err + } + keys = append(keys, schemaKeyFromTableKey(dbKey)) + } + + return keys, nil +} + +func (s *RDBConfigStore) getProviderKeyByName(ctx context.Context, txDB *gorm.DB, provider schemas.ModelProvider, keyID string) (*tables.TableKey, error) { + var dbKey tables.TableKey + if err := txDB.WithContext(ctx). + Table("config_keys"). + Select("config_keys.*"). + Joins("JOIN config_providers ON config_providers.id = config_keys.provider_id"). + Where("config_providers.name = ? AND config_keys.key_id = ?", string(provider), keyID). + First(&dbKey).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, err + } + return &dbKey, nil +} + +// GetProviderKey retrieves a single key for a provider. +func (s *RDBConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { + dbKey, err := s.getProviderKeyByName(ctx, s.db, provider, keyID) + if err != nil { + return nil, err + } + + key := schemaKeyFromTableKey(*dbKey) + return &key, nil +} + +// CreateProviderKey creates a new key for an existing provider. +func (s *RDBConfigStore) CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + var dbProvider tables.TableProvider + if err := txDB.WithContext(ctx).Where("name = ?", string(provider)).First(&dbProvider).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrNotFound + } + return err + } + dbKey, err := tableKeyFromSchemaKey(dbProvider, key) + if err != nil { + return err + } + if err := txDB.WithContext(ctx).Create(&dbKey).Error; err != nil { + return s.parseGormError(err) + } + return nil +} + +// UpdateProviderKey updates a single key for an existing provider. +func (s *RDBConfigStore) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + + existingKey, err := s.getProviderKeyByName(ctx, txDB, provider, keyID) + if err != nil { + return err + } + + dbKey, err := tableKeyFromSchemaKey(tables.TableProvider{ + ID: existingKey.ProviderID, + Name: existingKey.Provider, + }, key) + if err != nil { + return err + } + dbKey.ID = existingKey.ID + dbKey.KeyID = existingKey.KeyID + dbKey.ProviderID = existingKey.ProviderID + dbKey.Provider = existingKey.Provider + dbKey.ConfigHash = existingKey.ConfigHash + dbKey.EncryptionStatus = existingKey.EncryptionStatus + dbKey.CreatedAt = existingKey.CreatedAt // Preserve original creation timestamp + + if err := txDB.WithContext(ctx).Save(&dbKey).Error; err != nil { + return s.parseGormError(err) + } + + return nil +} + +// DeleteProviderKey deletes a single key for an existing provider. +func (s *RDBConfigStore) DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + + providerIDSubquery := txDB.Model(&tables.TableProvider{}). + Select("id"). + Where("name = ?", string(provider)) + + result := txDB.WithContext(ctx). + Where("provider_id = (?) AND key_id = ?", providerIDSubquery, keyID). + Delete(&tables.TableKey{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + // GetProviders retrieves all providers from the database with their governance relationships. func (s *RDBConfigStore) GetProviders(ctx context.Context) ([]tables.TableProvider, error) { var providers []tables.TableProvider @@ -880,26 +1087,25 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, // This will never happen, but just in case. clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { - // Dereference IsPingAvailable pointer, defaulting to true if nil - isPingAvailable := true - if dbClient.IsPingAvailable != nil { - isPingAvailable = *dbClient.IsPingAvailable - } clientConfigs[i] = &schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - IsCodeModeClient: dbClient.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: dbClient.ConnectionString, - StdioConfig: dbClient.StdioConfig, - AuthType: schemas.MCPAuthType(dbClient.AuthType), - OauthConfigID: dbClient.OauthConfigID, - ToolsToExecute: dbClient.ToolsToExecute, - ToolsToAutoExecute: dbClient.ToolsToAutoExecute, - Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, - ToolPricing: dbClient.ToolPricing, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, + AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, + DiscoveredTools: dbClient.DiscoveredTools, + DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, } } return &schemas.MCPConfig{ @@ -913,32 +1119,32 @@ func (s *RDBConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, return nil, err } toolManagerConfig := schemas.MCPToolManagerConfig{ - ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, - MaxAgentDepth: clientConfig.MCPAgentDepth, - CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), + ToolExecutionTimeout: time.Duration(clientConfig.MCPToolExecutionTimeout) * time.Second, + MaxAgentDepth: clientConfig.MCPAgentDepth, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(clientConfig.MCPCodeModeBindingLevel), + DisableAutoToolInject: clientConfig.MCPDisableAutoToolInject, } clientConfigs := make([]*schemas.MCPClientConfig, len(dbMCPClients)) for i, dbClient := range dbMCPClients { - // Dereference IsPingAvailable pointer, defaulting to true if nil - isPingAvailable := true - if dbClient.IsPingAvailable != nil { - isPingAvailable = *dbClient.IsPingAvailable - } clientConfigs[i] = &schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - IsCodeModeClient: dbClient.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: dbClient.ConnectionString, - StdioConfig: dbClient.StdioConfig, - AuthType: schemas.MCPAuthType(dbClient.AuthType), - OauthConfigID: dbClient.OauthConfigID, - ToolsToExecute: dbClient.ToolsToExecute, - ToolsToAutoExecute: dbClient.ToolsToAutoExecute, - Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, - ToolPricing: dbClient.ToolPricing, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: dbClient.IsPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, + ToolPricing: dbClient.ToolPricing, + DiscoveredTools: dbClient.DiscoveredTools, + DiscoveredToolNameMapping: dbClient.DiscoveredToolNameMapping, } } return &schemas.MCPConfig{ @@ -1023,19 +1229,21 @@ func (s *RDBConfigStore) CreateMCPClientConfig(ctx context.Context, clientConfig } // Create new client dbClient := tables.TableMCPClient{ - ClientID: clientConfigCopy.ID, - Name: clientConfigCopy.Name, - IsCodeModeClient: clientConfigCopy.IsCodeModeClient, - ConnectionType: string(clientConfigCopy.ConnectionType), - ConnectionString: clientConfigCopy.ConnectionString, - StdioConfig: clientConfigCopy.StdioConfig, - AuthType: string(clientConfigCopy.AuthType), - OauthConfigID: clientConfigCopy.OauthConfigID, - ToolsToExecute: clientConfigCopy.ToolsToExecute, - ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, - Headers: clientConfigCopy.Headers, - IsPingAvailable: &clientConfigCopy.IsPingAvailable, - ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), + ClientID: clientConfigCopy.ID, + Name: clientConfigCopy.Name, + IsCodeModeClient: clientConfigCopy.IsCodeModeClient, + ConnectionType: string(clientConfigCopy.ConnectionType), + ConnectionString: clientConfigCopy.ConnectionString, + StdioConfig: clientConfigCopy.StdioConfig, + AuthType: string(clientConfigCopy.AuthType), + OauthConfigID: clientConfigCopy.OauthConfigID, + ToolsToExecute: clientConfigCopy.ToolsToExecute, + ToolsToAutoExecute: clientConfigCopy.ToolsToAutoExecute, + Headers: clientConfigCopy.Headers, + AllowedExtraHeaders: clientConfigCopy.AllowedExtraHeaders, + IsPingAvailable: clientConfigCopy.IsPingAvailable, + ToolSyncInterval: int(clientConfigCopy.ToolSyncInterval.Minutes()), + AllowOnAllVirtualKeys: clientConfigCopy.AllowOnAllVirtualKeys, } if err := tx.WithContext(ctx).Create(&dbClient).Error; err != nil { return s.parseGormError(err) @@ -1094,6 +1302,13 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c if err != nil { return fmt.Errorf("failed to marshal headers: %w", err) } + if clientConfigCopy.AllowedExtraHeaders == nil { + clientConfigCopy.AllowedExtraHeaders = []string{} + } + allowedExtraHeadersJSON, err := json.Marshal(clientConfigCopy.AllowedExtraHeaders) + if err != nil { + return fmt.Errorf("failed to marshal allowed_extra_headers: %w", err) + } if clientConfigCopy.ToolPricing == nil { clientConfigCopy.ToolPricing = map[string]float64{} @@ -1120,8 +1335,10 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c "tools_to_execute_json": string(toolsToExecuteJSON), "tools_to_auto_execute_json": string(toolsToAutoExecuteJSON), "headers_json": headersJSONStr, + "allowed_extra_headers_json": string(allowedExtraHeadersJSON), "tool_pricing_json": string(toolPricingJSON), "tool_sync_interval": clientConfigCopy.ToolSyncInterval, + "allow_on_all_virtual_keys": clientConfigCopy.AllowOnAllVirtualKeys, "updated_at": time.Now(), } if encrypt.IsEnabled() { @@ -1141,6 +1358,26 @@ func (s *RDBConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, c }) } +// UpdateMCPClientDiscoveredTools persists discovered tools for a per-user OAuth MCP client. +func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error { + toolsJSON, err := json.Marshal(tools) + if err != nil { + return fmt.Errorf("failed to marshal discovered tools: %w", err) + } + mappingJSON, err := json.Marshal(toolNameMapping) + if err != nil { + return fmt.Errorf("failed to marshal tool name mapping: %w", err) + } + return s.db.WithContext(ctx). + Model(&tables.TableMCPClient{}). + Where("client_id = ?", clientID). + Updates(map[string]interface{}{ + "discovered_tools_json": string(toolsJSON), + "tool_name_mapping_json": string(mappingJSON), + "updated_at": time.Now(), + }).Error +} + // DeleteMCPClientConfig deletes an MCP client configuration from the database. func (s *RDBConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { return s.db.Transaction(func(tx *gorm.DB) error { @@ -1191,7 +1428,7 @@ func (s *RDBConfigStore) UpdateVectorStoreConfig(ctx context.Context, config *ve if err != nil { return err } - var record = &tables.TableVectorStoreConfig{ + record := &tables.TableVectorStoreConfig{ Type: string(config.Type), Enabled: config.Enabled, Config: jsonConfig, @@ -1230,7 +1467,7 @@ func (s *RDBConfigStore) UpdateLogsStoreConfig(ctx context.Context, config *logs if err != nil { return err } - var record = &tables.TableLogStoreConfig{ + record := &tables.TableLogStoreConfig{ Enabled: config.Enabled, Type: string(config.Type), Config: jsonConfig, @@ -1315,6 +1552,130 @@ func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error } +func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) { + var overrides []tables.TablePricingOverride + q := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + if filters.ScopeKind != nil { + q = q.Where("scope_kind = ?", *filters.ScopeKind) + } + if filters.VirtualKeyID != nil { + q = q.Where("virtual_key_id = ?", *filters.VirtualKeyID) + } + if filters.ProviderID != nil { + q = q.Where("provider_id = ?", *filters.ProviderID) + } + if filters.ProviderKeyID != nil { + q = q.Where("provider_key_id = ?", *filters.ProviderKeyID) + } + if err := q.Order("created_at ASC").Find(&overrides).Error; err != nil { + return nil, s.parseGormError(err) + } + return overrides, nil +} + +func (s *RDBConfigStore) GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) { + baseQuery := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) + + if params.Search != "" { + search := "%" + strings.ToLower(params.Search) + "%" + baseQuery = baseQuery.Where("LOWER(name) LIKE ?", search) + } + if params.ScopeKind != nil { + baseQuery = baseQuery.Where("scope_kind = ?", *params.ScopeKind) + } + if params.VirtualKeyID != nil { + baseQuery = baseQuery.Where("virtual_key_id = ?", *params.VirtualKeyID) + } + if params.ProviderID != nil { + baseQuery = baseQuery.Where("provider_id = ?", *params.ProviderID) + } + if params.ProviderKeyID != nil { + baseQuery = baseQuery.Where("provider_key_id = ?", *params.ProviderKeyID) + } + + var totalCount int64 + if err := baseQuery.Count(&totalCount).Error; err != nil { + return nil, 0, err + } + + limit := params.Limit + offset := params.Offset + + if limit <= 0 { + limit = 25 + } else if limit > 100 { + limit = 100 + } + + if offset < 0 { + offset = 0 + } + + var overrides []tables.TablePricingOverride + if err := baseQuery. + Order("created_at ASC"). + Offset(offset). + Limit(limit). + Find(&overrides).Error; err != nil { + return nil, 0, s.parseGormError(err) + } + return overrides, totalCount, nil +} + +func (s *RDBConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) { + var override tables.TablePricingOverride + if err := s.db.WithContext(ctx).First(&override, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, s.parseGormError(err) + } + return &override, nil +} + +func (s *RDBConfigStore) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + if err := txDB.WithContext(ctx).Create(override).Error; err != nil { + return s.parseGormError(err) + } + return nil +} + +func (s *RDBConfigStore) UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + if err := txDB.WithContext(ctx).Save(override).Error; err != nil { + return s.parseGormError(err) + } + return nil +} + +func (s *RDBConfigStore) DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error { + var txDB *gorm.DB + if len(tx) > 0 { + txDB = tx[0] + } else { + txDB = s.db + } + res := txDB.WithContext(ctx).Delete(&tables.TablePricingOverride{}, "id = ?", id) + if res.Error != nil { + return s.parseGormError(res.Error) + } + if res.RowsAffected == 0 { + return ErrNotFound + } + return nil +} + // MODEL PARAMETERS METHODS // GetModelParameters retrieves model parameters for a specific model. @@ -1447,35 +1808,31 @@ func (s *RDBConfigStore) UpdatePlugin(ctx context.Context, plugin *tables.TableP txDB = s.db.Begin() localTx = true } - // Mark plugin as custom if path is not empty if plugin.Path != nil && strings.TrimSpace(*plugin.Path) != "" { plugin.IsCustom = true } else { plugin.IsCustom = false } - if err := txDB.WithContext(ctx).Delete(&tables.TablePlugin{}, "name = ?", plugin.Name).Error; err != nil { if localTx { txDB.Rollback() } return err } - if err := txDB.WithContext(ctx).Create(plugin).Error; err != nil { if localTx { txDB.Rollback() } return s.parseGormError(err) } - if localTx { return txDB.Commit().Error } - return nil } +// DeletePlugin deletes a plugin from the database. func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { @@ -1488,6 +1845,7 @@ func (s *RDBConfigStore) DeletePlugin(ctx context.Context, name string, tx ...*g // GOVERNANCE METHODS +// GetRedactedVirtualKeys retrieves redacted virtual keys from the database. func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []string) ([]tables.TableVirtualKey, error) { var virtualKeys []tables.TableVirtualKey @@ -1505,6 +1863,7 @@ func (s *RDBConfigStore) GetRedactedVirtualKeys(ctx context.Context, ids []strin return virtualKeys, nil } +// preloadCustomerRelations preloads the customer relations for a virtual key. func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB { relation := func(name string) string { if prefix == "" { @@ -1512,7 +1871,6 @@ func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB { } return prefix + name } - return db. Preload(relation("Teams")). Preload(relation("Budget")). @@ -1520,16 +1878,16 @@ func preloadCustomerRelations(db *gorm.DB, prefix string) *gorm.DB { Preload(relation("VirtualKeys")) } +// preloadVirtualKeyBaseRelations preloads the base relationships for a virtual key. func preloadVirtualKeyBaseRelations(db *gorm.DB) *gorm.DB { - db = db.Preload("Team").Preload("Team.Customer") - - db = db.Preload("Customer") - return db. - Preload("Budget"). + Preload("Team"). + Preload("Team.Customer"). + Preload("Customer"). + Preload("Budgets"). Preload("RateLimit"). Preload("ProviderConfigs"). - Preload("ProviderConfigs.Budget"). + Preload("ProviderConfigs.Budgets"). Preload("ProviderConfigs.RateLimit"). Preload("ProviderConfigs.Keys", func(db *gorm.DB) *gorm.DB { return db.Select("id, name, key_id, models_json, provider") @@ -1538,6 +1896,7 @@ func preloadVirtualKeyBaseRelations(db *gorm.DB) *gorm.DB { Preload("MCPConfigs.MCPClient") } +// preloadVirtualKeyDetailRelations preloads the detail relationships for a virtual key. func preloadVirtualKeyDetailRelations(db *gorm.DB) *gorm.DB { return preloadCustomerRelations(preloadVirtualKeyBaseRelations(db), "Customer.") } @@ -1657,7 +2016,6 @@ func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) valueHash := encrypt.HashSHA256(value) var virtualKey tables.TableVirtualKey query := preloadVirtualKeyBaseRelations(s.db.WithContext(ctx)) - // Use hash-based lookup if hash column is populated, fall back to plaintext for backward compat if err := query.Where("value_hash = ?", valueHash).First(&virtualKey).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1675,6 +2033,7 @@ func (s *RDBConfigStore) GetVirtualKeyByValue(ctx context.Context, value string) return &virtualKey, nil } +// CreateVirtualKey creates a new virtual key in the database. func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { @@ -1688,6 +2047,7 @@ func (s *RDBConfigStore) CreateVirtualKey(ctx context.Context, virtualKey *table return nil } +// UpdateVirtualKey updates an existing virtual key in the database. func (s *RDBConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *tables.TableVirtualKey, tx ...*gorm.DB) error { var txDB *gorm.DB if len(tx) > 0 { @@ -1788,33 +2148,26 @@ func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error return err } - // Collect budget and rate limit IDs from provider configs before deletion - var providerConfigBudgetIDs []string + // Delete provider config resources before deleting the configs themselves var providerConfigRateLimitIDs []string for _, pc := range virtualKey.ProviderConfigs { // Delete the keys join table entries if err := tx.WithContext(ctx).Exec("DELETE FROM governance_virtual_key_provider_config_keys WHERE table_virtual_key_provider_config_id = ?", pc.ID).Error; err != nil { return err } - // Collect budget and rate limit IDs for deletion after provider config - if pc.BudgetID != nil { - providerConfigBudgetIDs = append(providerConfigBudgetIDs, *pc.BudgetID) + // Delete budgets owned by this provider config + if err := tx.WithContext(ctx).Where("provider_config_id = ?", pc.ID).Delete(&tables.TableBudget{}).Error; err != nil { + return err } if pc.RateLimitID != nil { providerConfigRateLimitIDs = append(providerConfigRateLimitIDs, *pc.RateLimitID) } } - // Delete all provider configs associated with the virtual key first + // Delete all provider configs associated with the virtual key if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "virtual_key_id = ?", id).Error; err != nil { return err } - // Now delete the collected budgets and rate limits - for _, budgetID := range providerConfigBudgetIDs { - if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", budgetID).Error; err != nil { - return err - } - } for _, rateLimitID := range providerConfigRateLimitIDs { if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", rateLimitID).Error; err != nil { return err @@ -1824,8 +2177,10 @@ func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKeyMCPConfig{}, "virtual_key_id = ?", id).Error; err != nil { return err } - // Delete the budget associated with the virtual key - budgetID := virtualKey.BudgetID + // Delete budgets owned by this virtual key + if err := tx.WithContext(ctx).Where("virtual_key_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil { + return err + } rateLimitID := virtualKey.RateLimitID // Delete the virtual key if err := tx.WithContext(ctx).Delete(&tables.TableVirtualKey{}, "id = ?", id).Error; err != nil { @@ -1834,11 +2189,6 @@ func (s *RDBConfigStore) DeleteVirtualKey(ctx context.Context, id string) error } return err } - if budgetID != nil { - if err := tx.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { - return err - } - } // Delete the rate limit associated with the virtual key if rateLimitID != nil { if err := tx.WithContext(ctx).Delete(&tables.TableRateLimit{}, "id = ?", *rateLimitID).Error; err != nil { @@ -2036,18 +2386,15 @@ func (s *RDBConfigStore) DeleteVirtualKeyProviderConfig(ctx context.Context, id } return err } - // Store the budget and rate limit IDs before deleting - budgetID := providerConfig.BudgetID + // Store the rate limit ID before deleting rateLimitID := providerConfig.RateLimitID - // Delete the provider config first - if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil { + // Delete budgets owned by this provider config + if err := txDB.WithContext(ctx).Where("provider_config_id = ?", id).Delete(&tables.TableBudget{}).Error; err != nil { return err } - // Delete the budget if it exists - if budgetID != nil { - if err := txDB.WithContext(ctx).Delete(&tables.TableBudget{}, "id = ?", *budgetID).Error; err != nil { - return err - } + // Delete the provider config + if err := txDB.WithContext(ctx).Delete(&tables.TableVirtualKeyProviderConfig{}, "id = ?", id).Error; err != nil { + return err } // Delete the rate limit if it exists if rateLimitID != nil { @@ -2071,12 +2418,33 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey return nil, nil } var mcpConfigs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { + if err := s.db.WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { return nil, err } return mcpConfigs, nil } +// GetVirtualKeyMCPConfigsByMCPClientID retrieves all VK MCP configs for a given MCP client. +func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) { + var configs []tables.TableVirtualKeyMCPConfig + if err := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Find(&configs).Error; err != nil { + return nil, err + } + return configs, nil +} + +// GetVirtualKeyMCPConfigsByMCPClientIDs retrieves all VK MCP configs for a set of MCP client IDs in one query. +func (s *RDBConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) { + if len(mcpClientIDs) == 0 { + return nil, nil + } + var configs []tables.TableVirtualKeyMCPConfig + if err := s.db.WithContext(ctx).Where("mcp_client_id IN ?", mcpClientIDs).Find(&configs).Error; err != nil { + return nil, err + } + return configs, nil +} + // CreateVirtualKeyMCPConfig creates a new virtual key MCP config in the database. func (s *RDBConfigStore) CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error { var txDB *gorm.DB @@ -2850,6 +3218,7 @@ func (s *RDBConfigStore) GetModelConfigs(ctx context.Context) ([]tables.TableMod return modelConfigs, nil } +// GetModelConfigsPaginated retrieves model configs with pagination, filtering, and search support. func (s *RDBConfigStore) GetModelConfigsPaginated(ctx context.Context, params ModelConfigsQueryParams) ([]tables.TableModelConfig, int64, error) { baseQuery := s.db.WithContext(ctx).Model(&tables.TableModelConfig{}) @@ -3010,6 +3379,7 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo var modelConfigs []tables.TableModelConfig var providers []tables.TableProvider var routingRules []tables.TableRoutingRule + var pricingOverrides []tables.TablePricingOverride var governanceConfigs []tables.TableGovernanceConfig if err := s.db.WithContext(ctx). @@ -3041,12 +3411,15 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo if err := s.loadRoutingRulesOrdered(ctx, &routingRules); err != nil { return nil, err } + if err := s.db.WithContext(ctx).Find(&pricingOverrides).Error; err != nil { + return nil, err + } // Fetching governance config for username and password if err := s.db.WithContext(ctx).Find(&governanceConfigs).Error; err != nil { return nil, err } // Check if any config is present - if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 && len(routingRules) == 0 { + if len(virtualKeys) == 0 && len(teams) == 0 && len(customers) == 0 && len(budgets) == 0 && len(rateLimits) == 0 && len(modelConfigs) == 0 && len(providers) == 0 && len(governanceConfigs) == 0 && len(routingRules) == 0 && len(pricingOverrides) == 0 { return nil, nil } var authConfig *AuthConfig @@ -3078,15 +3451,16 @@ func (s *RDBConfigStore) GetGovernanceConfig(ctx context.Context) (*GovernanceCo } } return &GovernanceConfig{ - VirtualKeys: virtualKeys, - Teams: teams, - Customers: customers, - Budgets: budgets, - RateLimits: rateLimits, - ModelConfigs: modelConfigs, - Providers: providers, - RoutingRules: routingRules, - AuthConfig: authConfig, + VirtualKeys: virtualKeys, + Teams: teams, + Customers: customers, + Budgets: budgets, + RateLimits: rateLimits, + ModelConfigs: modelConfigs, + Providers: providers, + RoutingRules: routingRules, + PricingOverrides: pricingOverrides, + AuthConfig: authConfig, }, nil } @@ -3105,7 +3479,6 @@ func (s *RDBConfigStore) GetAuthConfig(ctx context.Context) (*AuthConfig, error) if !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } - } if err := s.db.WithContext(ctx).First(&tables.TableGovernanceConfig{}, "key = ?", tables.ConfigIsAuthEnabledKey).Select("value").Scan(&isEnabled).Error; err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -3601,3 +3974,486 @@ func (s *RDBConfigStore) GetOauthConfigByTokenID(ctx context.Context, tokenID st } return &config, nil } + +// ---------- Per-User OAuth Session CRUD ---------- + +// GetOauthUserSessionByID retrieves a per-user OAuth session by its ID +func (s *RDBConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { + var session tables.TableOauthUserSession + result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get oauth user session: %w", result.Error) + } + return &session, nil +} + +// GetOauthUserSessionByState retrieves a per-user OAuth session by its state token +func (s *RDBConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { + var session tables.TableOauthUserSession + result := s.db.WithContext(ctx).Where("state = ?", state).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get oauth user session by state: %w", result.Error) + } + return &session, nil +} + +// ClaimOauthUserSessionByState atomically claims a pending per-user OAuth session by its state token. +// Returns nil if the session doesn't exist or has already been claimed by another request. +func (s *RDBConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { + var session tables.TableOauthUserSession + result := s.db.WithContext(ctx).Where("state = ? AND status = ?", state, "pending").First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to claim oauth user session by state: %w", result.Error) + } + // Atomically transition from "pending" to "claiming" to prevent concurrent claims + updateResult := s.db.WithContext(ctx).Model(&tables.TableOauthUserSession{}). + Where("id = ? AND status = ?", session.ID, "pending"). + Update("status", "claiming") + if updateResult.Error != nil { + return nil, fmt.Errorf("failed to claim oauth user session: %w", updateResult.Error) + } + if updateResult.RowsAffected == 0 { + return nil, nil // Another request already claimed this session + } + session.Status = "claiming" + return &session, nil +} + +// GetOauthUserSessionBySessionToken retrieves a per-user OAuth session by its Bifrost session token (hashed lookup) +func (s *RDBConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { + var session tables.TableOauthUserSession + tokenHash := encrypt.HashSHA256(sessionToken) + result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get oauth user session by session token: %w", result.Error) + } + return &session, nil +} + +// CreateOauthUserSession creates a new per-user OAuth session +func (s *RDBConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { + result := s.db.WithContext(ctx).Create(session) + if result.Error != nil { + return fmt.Errorf("failed to create oauth user session: %w", result.Error) + } + return nil +} + +// UpdateOauthUserSession updates an existing per-user OAuth session +func (s *RDBConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { + result := s.db.WithContext(ctx).Save(session) + if result.Error != nil { + return fmt.Errorf("failed to update oauth user session: %w", result.Error) + } + return nil +} + +// ---------- Per-User OAuth Token CRUD ---------- + +// GetOauthUserTokenBySessionToken retrieves a per-user OAuth token by its Bifrost session token +// GetOauthUserTokenByIdentity looks up an upstream OAuth token by user identity and MCP client. +// Priority: userID > virtualKeyID > sessionToken (fallback for anonymous users). +func (s *RDBConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { + var token tables.TableOauthUserToken + var result *gorm.DB + + if userID != "" { + result = s.db.WithContext(ctx).Where("user_id = ? AND mcp_client_id = ?", userID, mcpClientID).First(&token) + } else if virtualKeyID != "" { + result = s.db.WithContext(ctx).Where("virtual_key_id = ? AND mcp_client_id = ?", virtualKeyID, mcpClientID).First(&token) + } else if sessionToken != "" { + result = s.db.WithContext(ctx).Where("session_token = ? AND mcp_client_id = ?", sessionToken, mcpClientID).First(&token) + } else { + return nil, nil + } + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get oauth user token by identity: %w", result.Error) + } + return &token, nil +} + +func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { + var token tables.TableOauthUserToken + tokenHash := encrypt.HashSHA256(sessionToken) + result := s.db.WithContext(ctx).Where("session_token_hash = ?", tokenHash).First(&token) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get oauth user token by session token: %w", result.Error) + } + return &token, nil +} + +// CreateOauthUserToken creates or replaces a per-user OAuth token. +// When an identity (VirtualKeyID or UserID) is set, any existing token for the +// same identity + MCPClientID pair is replaced to keep resolution deterministic. +func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { + // Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing + // duplicate tokens when concurrent requests race on the same identity+client pair. + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if token.UserID != nil && *token.UserID != "" { + var existing tables.TableOauthUserToken + err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error + if err == nil { + token.ID = existing.ID // reuse the row + return tx.Save(token).Error + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query oauth user token: %w", err) + } + } else if token.VirtualKeyID != nil && *token.VirtualKeyID != "" { + var existing tables.TableOauthUserToken + err := tx.Where("virtual_key_id = ? AND mcp_client_id = ?", *token.VirtualKeyID, token.MCPClientID).First(&existing).Error + if err == nil { + token.ID = existing.ID // reuse the row + return tx.Save(token).Error + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query oauth user token: %w", err) + } + } + + if err := tx.Create(token).Error; err != nil { + return fmt.Errorf("failed to create oauth user token: %w", err) + } + return nil + }) +} + +// UpdateOauthUserToken updates an existing per-user OAuth token +func (s *RDBConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { + result := s.db.WithContext(ctx).Save(token) + if result.Error != nil { + return fmt.Errorf("failed to update oauth user token: %w", result.Error) + } + return nil +} + +// DeleteOauthUserToken deletes a per-user OAuth token by its ID +func (s *RDBConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { + result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TableOauthUserToken{}) + if result.Error != nil { + return fmt.Errorf("failed to delete oauth user token: %w", result.Error) + } + return nil +} + +// DeleteOauthUserTokensByMCPClient deletes all per-user OAuth tokens for a specific MCP client +func (s *RDBConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { + result := s.db.WithContext(ctx).Where("mcp_client_id = ?", mcpClientID).Delete(&tables.TableOauthUserToken{}) + if result.Error != nil { + return fmt.Errorf("failed to delete oauth user tokens for mcp client: %w", result.Error) + } + return nil +} + +// ---------- Per-User OAuth Authorization Server CRUD ---------- + +// GetPerUserOAuthClientByClientID retrieves a dynamically registered OAuth client by its client_id. +func (s *RDBConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { + var client tables.TablePerUserOAuthClient + result := s.db.WithContext(ctx).Where("client_id = ?", clientID).First(&client) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth client: %w", result.Error) + } + return &client, nil +} + +// CreatePerUserOAuthClient creates a new dynamically registered OAuth client. +func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { + result := s.db.WithContext(ctx).Create(client) + if result.Error != nil { + return fmt.Errorf("failed to create per-user oauth client: %w", result.Error) + } + return nil +} + +// GetPerUserOAuthSessionByAccessToken retrieves a Bifrost-issued session by its access token. +func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { + var session tables.TablePerUserOAuthSession + tokenHash := encrypt.HashSHA256(accessToken) + result := s.db.WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { + return db.Select("id, name, value, encryption_status") + }).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth session: %w", result.Error) + } + return &session, nil +} + +// GetPerUserOAuthSessionByID retrieves a Bifrost-issued session by its ID. +func (s *RDBConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { + var session tables.TablePerUserOAuthSession + result := s.db.WithContext(ctx).Where("id = ?", id).First(&session) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth session by id: %w", result.Error) + } + return &session, nil +} + +// CreatePerUserOAuthSession creates a new Bifrost-issued OAuth session. +func (s *RDBConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { + result := s.db.WithContext(ctx).Create(session) + if result.Error != nil { + return fmt.Errorf("failed to create per-user oauth session: %w", result.Error) + } + return nil +} + +// UpdatePerUserOAuthSession updates a Bifrost-issued OAuth session (e.g., to attach user identity). +func (s *RDBConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { + result := s.db.WithContext(ctx).Save(session) + if result.Error != nil { + return fmt.Errorf("failed to update per-user oauth session: %w", result.Error) + } + return nil +} + +// DeletePerUserOAuthSession deletes a Bifrost-issued OAuth session by ID. +func (s *RDBConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { + result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthSession{}) + if result.Error != nil { + return fmt.Errorf("failed to delete per-user oauth session: %w", result.Error) + } + return nil +} + +// GetPerUserOAuthCodeByCode retrieves an authorization code record. +func (s *RDBConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { + var codeRecord tables.TablePerUserOAuthCode + codeHash := encrypt.HashSHA256(code) + result := s.db.WithContext(ctx).Where("code_hash = ?", codeHash).First(&codeRecord) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth code: %w", result.Error) + } + return &codeRecord, nil +} + +// CreatePerUserOAuthCode creates a new authorization code record. +func (s *RDBConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { + result := s.db.WithContext(ctx).Create(code) + if result.Error != nil { + return fmt.Errorf("failed to create per-user oauth code: %w", result.Error) + } + return nil +} + +// ClaimPerUserOAuthCode atomically marks an authorization code as used. +// Returns the code record if successfully claimed, nil if already used or not found. +func (s *RDBConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { + codeHash := encrypt.HashSHA256(code) + var codeRecord tables.TablePerUserOAuthCode + result := s.db.WithContext(ctx).Where("code_hash = ? AND used = ?", codeHash, false).First(&codeRecord) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to find per-user oauth code: %w", result.Error) + } + // Atomically mark as used + updateResult := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthCode{}). + Where("id = ? AND used = ?", codeRecord.ID, false). + Update("used", true) + if updateResult.Error != nil { + return nil, fmt.Errorf("failed to claim per-user oauth code: %w", updateResult.Error) + } + if updateResult.RowsAffected == 0 { + return nil, nil // Another request already claimed it + } + codeRecord.Used = true + return &codeRecord, nil +} + +// UpdatePerUserOAuthCode updates an authorization code record (e.g., marking as used). +func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { + result := s.db.WithContext(ctx).Save(code) + if result.Error != nil { + return fmt.Errorf("failed to update per-user oauth code: %w", result.Error) + } + return nil +} + +// ---------- Per-User OAuth Pending Flow CRUD ---------- + +// GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID. +func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { + var flow tables.TablePerUserOAuthPendingFlow + result := s.db.WithContext(ctx).Where("id = ?", id).First(&flow) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth pending flow: %w", result.Error) + } + return &flow, nil +} + +// CreatePerUserOAuthPendingFlow persists a new pending consent flow. +func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + result := s.db.WithContext(ctx).Create(flow) + if result.Error != nil { + return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error) + } + return nil +} + +// UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step). +func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + result := s.db.WithContext(ctx).Save(flow) + if result.Error != nil { + return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error) + } + return nil +} + +// DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted. +func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { + result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error) + } + return nil +} + +func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { + now := time.Now().UTC() + result := s.db.WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) + } + if result.RowsAffected == 0 { + // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). + var count int64 + if err := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { + return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) + } + if count > 0 { + return 0, schemas.ErrPerUserOAuthPendingFlowExpired + } + } + return result.RowsAffected, nil +} + +// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session, +// and creates the authorization code in a single transaction. +func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { + var rowsAffected int64 + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 1. Consume the pending flow (atomic idempotency guard). + // Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check. + now := time.Now().UTC() + result := tx.Where("id = ? AND expires_at > ?", flowID, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) + } + rowsAffected = result.RowsAffected + if rowsAffected == 0 { + // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). + var count int64 + if err := tx.Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", flowID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) + } + if count > 0 { + return schemas.ErrPerUserOAuthPendingFlowExpired + } + // Record gone β€” consumed by a concurrent request; caller treats as conflict. + return nil + } + + // 2. Create the Bifrost session. + if err := tx.Create(session).Error; err != nil { + return fmt.Errorf("failed to create per-user oauth session: %w", err) + } + + // 3. Create the authorization code. + if err := tx.Create(code).Error; err != nil { + return fmt.Errorf("failed to create per-user oauth code: %w", err) + } + + return nil + }) + if err != nil { + return 0, err + } + return rowsAffected, nil +} + +// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID. +func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { + if strings.TrimSpace(gatewaySessionID) == "" { + return nil, fmt.Errorf("gateway session id is required") + } + // Find all tokens whose session_token_hash matches any upstream session + // linked to this gateway session ID. This supports per-service proxy tokens + // (e.g. "flow::") where each MCP service gets its own hash. + var tokens []tables.TableOauthUserToken + subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.db.WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) + if result.Error != nil { + return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error) + } + return tokens, nil +} + +// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions +// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID. +func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { + if strings.TrimSpace(gatewaySessionID) == "" { + return fmt.Errorf("gateway session id is required") + } + if strings.TrimSpace(realSessionToken) == "" { + return fmt.Errorf("real session token is required") + } + realTokenHash := encrypt.HashSHA256(realSessionToken) + + // Always overwrite both identity columns from the finalized values so stale + // identities from a prior flow phase cannot persist and cause GetOauthUserTokenByIdentity + // to resolve this token under the wrong identity. + updates := map[string]interface{}{ + "session_token": realSessionToken, + "session_token_hash": realTokenHash, + "virtual_key_id": virtualKeyID, + "user_id": userID, + } + + // Update all tokens whose session_token_hash matches any upstream session + // linked to this gateway session ID. + subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.db.WithContext(ctx).Model(&tables.TableOauthUserToken{}). + Where("session_token_hash IN (?)", subquery). + Updates(updates) + if result.Error != nil { + return fmt.Errorf("failed to transfer oauth user tokens from gateway session: %w", result.Error) + } + s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) + return nil +} diff --git a/framework/configstore/rdb_test.go b/framework/configstore/rdb_test.go index a2d25600cf..406e7a5cfd 100644 --- a/framework/configstore/rdb_test.go +++ b/framework/configstore/rdb_test.go @@ -201,6 +201,86 @@ func TestUpdateProvidersConfig_MultipleKeys(t *testing.T) { assert.Len(t, result["anthropic"].Keys, 1) } +func TestProviderKeyCRUD(t *testing.T) { + store := setupRDBTestStore(t) + ctx := context.Background() + + err := store.UpdateProvidersConfig(ctx, map[schemas.ModelProvider]ProviderConfig{ + "openai": {}, + }) + require.NoError(t, err) + + keys, err := store.GetProviderKeys(ctx, "openai") + require.NoError(t, err) + assert.Empty(t, keys) + + key := schemas.Key{ + ID: "key-uuid-1", + Name: "openai-primary", + Value: *schemas.NewEnvVar("sk-test-key-v1"), + Weight: 1.0, + } + + err = store.CreateProviderKey(ctx, "openai", key) + require.NoError(t, err) + + keys, err = store.GetProviderKeys(ctx, "openai") + require.NoError(t, err) + require.Len(t, keys, 1) + assert.Equal(t, "openai-primary", keys[0].Name) + + storedKey, err := store.GetProviderKey(ctx, "openai", key.ID) + require.NoError(t, err) + require.NotNil(t, storedKey) + assert.Equal(t, "sk-test-key-v1", storedKey.Value.Val) + + key.Value = *schemas.NewEnvVar("sk-test-key-v2") + key.Weight = 2.0 + + err = store.UpdateProviderKey(ctx, "openai", key.ID, key) + require.NoError(t, err) + + storedKey, err = store.GetProviderKey(ctx, "openai", key.ID) + require.NoError(t, err) + require.NotNil(t, storedKey) + assert.Equal(t, "sk-test-key-v2", storedKey.Value.Val) + assert.Equal(t, 2.0, storedKey.Weight) + + err = store.DeleteProviderKey(ctx, "openai", key.ID) + require.NoError(t, err) + + keys, err = store.GetProviderKeys(ctx, "openai") + require.NoError(t, err) + assert.Empty(t, keys) +} + +func TestProviderKeyCRUD_ProviderMustExist(t *testing.T) { + store := setupRDBTestStore(t) + ctx := context.Background() + + key := schemas.Key{ + ID: "key-uuid-1", + Name: "openai-primary", + Value: *schemas.NewEnvVar("sk-test-key-v1"), + Weight: 1.0, + } + + err := store.CreateProviderKey(ctx, "openai", key) + require.ErrorIs(t, err, ErrNotFound) + + _, err = store.GetProviderKeys(ctx, "openai") + require.ErrorIs(t, err, ErrNotFound) + + _, err = store.GetProviderKey(ctx, "openai", key.ID) + require.ErrorIs(t, err, ErrNotFound) + + err = store.UpdateProviderKey(ctx, "openai", key.ID, key) + require.ErrorIs(t, err, ErrNotFound) + + err = store.DeleteProviderKey(ctx, "openai", key.ID) + require.ErrorIs(t, err, ErrNotFound) +} + // ============================================================================= // Budget Tests // ============================================================================= @@ -440,24 +520,28 @@ func TestCreateVirtualKey_WithBudgetAndRateLimit(t *testing.T) { require.NoError(t, err) // Create virtual key with references - budgetID := "budget-for-vk" rateLimitID := "rate-limit-for-vk" + vkID := "vk-with-refs" vk := &tables.TableVirtualKey{ - ID: "vk-with-refs", + ID: vkID, Name: "VK With References", Value: "vk-refs-value", IsActive: true, - BudgetID: &budgetID, RateLimitID: &rateLimitID, } err = store.CreateVirtualKey(ctx, vk) require.NoError(t, err) + // Link the existing budget to the VK via FK + budget.VirtualKeyID = &vkID + err = store.UpdateBudget(ctx, budget) + require.NoError(t, err) + result, err := store.GetVirtualKey(ctx, "vk-with-refs") require.NoError(t, err) - assert.NotNil(t, result.BudgetID) - assert.Equal(t, "budget-for-vk", *result.BudgetID) + assert.Len(t, result.Budgets, 1) + assert.Equal(t, "budget-for-vk", result.Budgets[0].ID) assert.NotNil(t, result.RateLimitID) assert.Equal(t, "rate-limit-for-vk", *result.RateLimitID) } @@ -919,19 +1003,23 @@ func TestFullVirtualKeyFlow(t *testing.T) { require.NoError(t, err) // Step 4: Create virtual key - budgetID := "integration-budget" rateLimitID := "integration-rate-limit" + integrationVKID := "integration-vk" vk := &tables.TableVirtualKey{ - ID: "integration-vk", + ID: integrationVKID, Name: "Integration Virtual Key", Value: "vk-integration-xyz", IsActive: true, - BudgetID: &budgetID, RateLimitID: &rateLimitID, } err = store.CreateVirtualKey(ctx, vk) require.NoError(t, err) + // Link the existing budget to the VK via FK + budget.VirtualKeyID = &integrationVKID + err = store.UpdateBudget(ctx, budget) + require.NoError(t, err) + // Step 5: Create provider config with key reference weight := 1.0 pc := &tables.TableVirtualKeyProviderConfig{ @@ -949,7 +1037,7 @@ func TestFullVirtualKeyFlow(t *testing.T) { result, err := store.GetVirtualKey(ctx, "integration-vk") require.NoError(t, err) assert.Equal(t, "Integration Virtual Key", result.Name) - assert.NotNil(t, result.BudgetID) + assert.Len(t, result.Budgets, 1) assert.NotNil(t, result.RateLimitID) configs, err := store.GetVirtualKeyProviderConfigs(ctx, "integration-vk") diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 4d6d960bbd..bb4f54966b 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -62,6 +62,25 @@ type CustomersQueryParams struct { Search string } +// PricingOverrideFilters holds the filters for pricing overrides. +type PricingOverrideFilters struct { + ScopeKind *string + VirtualKeyID *string + ProviderID *string + ProviderKeyID *string +} + +// PricingOverridesQueryParams holds pagination, filtering, and search parameters for pricing override queries. +type PricingOverridesQueryParams struct { + Limit int + Offset int + Search string + ScopeKind *string + VirtualKeyID *string + ProviderID *string + ProviderKeyID *string +} + // ConfigStore is the interface for the config store. type ConfigStore interface { // Health check @@ -85,6 +104,11 @@ type ConfigStore interface { DeleteProvider(ctx context.Context, provider schemas.ModelProvider, tx ...*gorm.DB) error GetProvidersConfig(ctx context.Context) (map[schemas.ModelProvider]ProviderConfig, error) GetProviderConfig(ctx context.Context, provider schemas.ModelProvider) (*ProviderConfig, error) + GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) + GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) + CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error + UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error + DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error GetProviders(ctx context.Context) ([]tables.TableProvider, error) GetProvider(ctx context.Context, provider schemas.ModelProvider) (*tables.TableProvider, error) UpdateStatus(ctx context.Context, provider schemas.ModelProvider, keyID string, status, errorMsg string) error @@ -96,6 +120,7 @@ type ConfigStore interface { GetMCPClientsPaginated(ctx context.Context, params MCPClientsQueryParams) ([]tables.TableMCPClient, int64, error) CreateMCPClientConfig(ctx context.Context, clientConfig *schemas.MCPClientConfig) error UpdateMCPClientConfig(ctx context.Context, id string, clientConfig *tables.TableMCPClient) error + UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error DeleteMCPClientConfig(ctx context.Context, id string) error // Vector store config CRUD @@ -136,6 +161,8 @@ type ConfigStore interface { // Virtual key MCP config CRUD GetVirtualKeyMCPConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyMCPConfig, error) + GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) + GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) CreateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error UpdateVirtualKeyMCPConfig(ctx context.Context, virtualKeyMCPConfig *tables.TableVirtualKeyMCPConfig, tx ...*gorm.DB) error DeleteVirtualKeyMCPConfig(ctx context.Context, id uint, tx ...*gorm.DB) error @@ -221,6 +248,14 @@ type ConfigStore interface { UpsertModelPrices(ctx context.Context, pricing *tables.TableModelPricing, tx ...*gorm.DB) error DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error + // Governance pricing overrides CRUD + GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) + GetPricingOverridesPaginated(ctx context.Context, params PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) + GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) + CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error + UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error + DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error + // Model parameters GetModelParameters(ctx context.Context, model string) (*tables.TableModelParameters, error) UpsertModelParameters(ctx context.Context, params *tables.TableModelParameters, tx ...*gorm.DB) error @@ -270,6 +305,55 @@ type ConfigStore interface { UpdateOauthToken(ctx context.Context, token *tables.TableOauthToken) error DeleteOauthToken(ctx context.Context, id string) error + // Per-user OAuth session CRUD + GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) + GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) + ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) + GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) + CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error + UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error + + // Per-user OAuth token CRUD + GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) + GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) + CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error + UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error + DeleteOauthUserToken(ctx context.Context, id string) error + DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error + + // Per-user OAuth Authorization Server CRUD (Bifrost as OAuth server) + GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) + CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error + GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) + GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) + CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error + UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error + DeletePerUserOAuthSession(ctx context.Context, id string) error + GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) + ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) + CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error + UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error + + // Per-user OAuth consent flow (pending flows before code issuance) + GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) + CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error + UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error + DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error + // ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of + // rows affected. Returns 0 if the flow was already consumed by a concurrent request. + ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) + // FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session, + // and creates the authorization code in a single transaction. Returns (0, nil) if the + // flow was already consumed by a concurrent request. + FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) + // GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID. + // Used during consent submit to discover which MCPs the user authenticated with. + // Queries tokens via upstream sessions matching the given gateway session ID. + GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) + // TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions + // (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record. + TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error + // Not found retry wrapper RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) @@ -288,6 +372,7 @@ type ConfigStore interface { DeletePrompt(ctx context.Context, id string) error // Prompt Repository - Versions + GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) diff --git a/framework/configstore/tables/budget.go b/framework/configstore/tables/budget.go index e35c530f3d..2d7d397d26 100644 --- a/framework/configstore/tables/budget.go +++ b/framework/configstore/tables/budget.go @@ -15,9 +15,9 @@ type TableBudget struct { LastReset time.Time `gorm:"index" json:"last_reset"` // Last time budget was reset CurrentUsage float64 `gorm:"default:0" json:"current_usage"` // Current usage in dollars - // CalendarAligned snaps LastReset to the start of the current calendar period (day, week, month, year) - // instead of the exact creation/update time, so budgets reset at clean calendar boundaries. - CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` + // Owner FKs: a budget belongs to at most one VK or one ProviderConfig + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id,omitempty"` + ProviderConfigID *uint `gorm:"index" json:"provider_config_id,omitempty"` // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash @@ -32,6 +32,11 @@ func (TableBudget) TableName() string { return "governance_budgets" } // BeforeSave hook for Budget to validate reset duration format and max limit func (b *TableBudget) BeforeSave(tx *gorm.DB) error { + // A budget belongs to at most one owner type + if b.VirtualKeyID != nil && b.ProviderConfigID != nil { + return fmt.Errorf("budget cannot belong to both a virtual key and a provider config") + } + // Validate that ResetDuration is in correct format (e.g., "30s", "5m", "1h", "1d", "1w", "1M", "1Y") if d, err := ParseDuration(b.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", b.ResetDuration) diff --git a/framework/configstore/tables/clientconfig.go b/framework/configstore/tables/clientconfig.go index c1bc5c8935..a9ff7fc7f6 100644 --- a/framework/configstore/tables/clientconfig.go +++ b/framework/configstore/tables/clientconfig.go @@ -29,10 +29,12 @@ type TableClientConfig struct { MCPToolExecutionTimeout int `gorm:"default:30" json:"mcp_tool_execution_timeout"` // Timeout for individual tool execution in seconds (default: 30) MCPCodeModeBindingLevel string `gorm:"default:server" json:"mcp_code_mode_binding_level"` // How tools are exposed in VFS: "server" or "tool" MCPToolSyncInterval int `gorm:"default:10" json:"mcp_tool_sync_interval"` // Global tool sync interval in minutes (default: 10, 0 = disabled) + MCPDisableAutoToolInject bool `gorm:"default:false" json:"mcp_disable_auto_tool_inject"` // When true, MCP tools are not injected into requests by default AsyncJobResultTTL int `gorm:"default:3600" json:"async_job_result_ttl"` // Default TTL for async job results in seconds (default: 3600 = 1 hour) RequiredHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string LoggingHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string HideDeletedVirtualKeysInFilters bool `gorm:"default:false" json:"hide_deleted_virtual_keys_in_filters"` // Hide deleted virtual keys in logs filter dropdowns + RoutingChainMaxDepth int `gorm:"default:10" json:"routing_chain_max_depth"` // Maximum depth for routing rule chain evaluation (default: 10) WhitelistedRoutesJSON string `gorm:"type:text" json:"-"` // JSON serialized []string // LiteLLM fallback flag diff --git a/framework/configstore/tables/encryption_test.go b/framework/configstore/tables/encryption_test.go index 8807570e14..314b28be59 100644 --- a/framework/configstore/tables/encryption_test.go +++ b/framework/configstore/tables/encryption_test.go @@ -1,6 +1,7 @@ package tables import ( + "os" "testing" "time" @@ -175,12 +176,12 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { Provider: "bedrock", KeyID: "bedrock-uuid-1", Value: *schemas.NewEnvVar("bedrock-val"), + Aliases: schemas.KeyAliases{"model-a": "profile-a"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ - AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), - SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), - Region: schemas.NewEnvVar("us-west-2"), - ARN: schemas.NewEnvVar("arn:aws:iam::123456789:role/test"), - Deployments: map[string]string{"model-a": "profile-a"}, + AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), + SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), + Region: schemas.NewEnvVar("us-west-2"), + ARN: schemas.NewEnvVar("arn:aws:iam::123456789:role/test"), BatchS3Config: &schemas.BatchS3Config{ Buckets: []schemas.S3BucketConfig{ {BucketName: "my-batch-bucket", Prefix: "jobs/", IsDefault: true}, @@ -197,9 +198,17 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { assert.NotEqual(t, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", raw["bedrock_secret_key"]) assert.NotEqual(t, "us-west-2", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:iam::123456789:role/test", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-a") - } + rawAliasesVal := raw["aliases_json"] + require.NotNil(t, rawAliasesVal, "aliases_json should be present in raw row") + var rawAliasesStr string + switch v := rawAliasesVal.(type) { + case string: + rawAliasesStr = v + case []byte: + rawAliasesStr = string(v) + } + require.NotEmpty(t, rawAliasesStr, "aliases_json should not be empty") + assert.NotContains(t, rawAliasesStr, "profile-a") if rawBatch, ok := raw["bedrock_batch_s3_config_json"].(string); ok { assert.NotContains(t, rawBatch, "my-batch-bucket") } @@ -213,7 +222,7 @@ func TestTableKey_BedrockFieldsEncryptDecrypt(t *testing.T) { assert.Equal(t, "us-west-2", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:iam::123456789:role/test", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-a", found.BedrockKeyConfig.Deployments["model-a"]) + assert.Equal(t, "profile-a", found.Aliases["model-a"]) require.NotNil(t, found.BedrockKeyConfig.BatchS3Config) require.Len(t, found.BedrockKeyConfig.BatchS3Config.Buckets, 1) assert.Equal(t, "my-batch-bucket", found.BedrockKeyConfig.BatchS3Config.Buckets[0].BucketName) @@ -1144,6 +1153,7 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { Provider: "custom", KeyID: "multi-uuid", Value: *schemas.NewEnvVar("multi-api-key"), + Aliases: schemas.KeyAliases{"claude-3": "profile-claude"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://azure.endpoint.com"), ClientID: schemas.NewEnvVar("multi-azure-cid"), @@ -1163,7 +1173,6 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { SessionToken: sessionToken, Region: schemas.NewEnvVar("eu-west-1"), ARN: schemas.NewEnvVar("arn:aws:bedrock:eu-west-1:123:role"), - Deployments: map[string]string{"claude-3": "profile-claude"}, }, } @@ -1180,9 +1189,17 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { assert.NotEqual(t, "us-central1", raw["vertex_region"]) assert.NotEqual(t, "eu-west-1", raw["bedrock_region"]) assert.NotEqual(t, "arn:aws:bedrock:eu-west-1:123:role", raw["bedrock_arn"]) - if rawDeploy, ok := raw["bedrock_deployments_json"].(string); ok { - assert.NotContains(t, rawDeploy, "profile-claude") - } + rawAliasesVal2 := raw["aliases_json"] + require.NotNil(t, rawAliasesVal2, "aliases_json should be present in raw row") + var rawAliasesStr2 string + switch v := rawAliasesVal2.(type) { + case string: + rawAliasesStr2 = v + case []byte: + rawAliasesStr2 = string(v) + } + require.NotEmpty(t, rawAliasesStr2, "aliases_json should not be empty") + assert.NotContains(t, rawAliasesStr2, "profile-claude") var found TableKey require.NoError(t, db.First(&found, key.ID).Error) @@ -1214,7 +1231,7 @@ func TestTableKey_AllProviderConfigs_EncryptDecrypt(t *testing.T) { assert.Equal(t, "eu-west-1", found.BedrockKeyConfig.Region.GetValue()) require.NotNil(t, found.BedrockKeyConfig.ARN) assert.Equal(t, "arn:aws:bedrock:eu-west-1:123:role", found.BedrockKeyConfig.ARN.GetValue()) - assert.Equal(t, "profile-claude", found.BedrockKeyConfig.Deployments["claude-3"]) + assert.Equal(t, "profile-claude", found.Aliases["claude-3"]) } // ============================================================================ @@ -1268,9 +1285,9 @@ func TestTableMCPClient_EncryptionDisabled_StoresPlaintext(t *testing.T) { db := setupTestDB(t) client := &TableMCPClient{ - ClientID: "mcp-dis-1", - Name: "disabled-mcp", - ConnectionType: "sse", + ClientID: "mcp-dis-1", + Name: "disabled-mcp", + ConnectionType: "sse", ConnectionString: schemas.NewEnvVar("https://mcp.example.com"), Headers: map[string]schemas.EnvVar{ "Authorization": *schemas.NewEnvVar("Bearer secret-token"), @@ -1708,3 +1725,258 @@ func TestPostgres_EncryptedColumns_AreText(t *testing.T) { }) } } + +// ============================================================================ +// Env-var-reference persistence regression tests +// +// These tests guard against a class of bugs where BeforeSave used GetValue() != "" +// to decide whether to persist a config field. When a field was set via env var +// reference (e.g. "env.AZURE_ENDPOINT") and the env var was not set on the server, +// GetValue() would return "" and the field β€” including the env reference β€” would be +// dropped from the DB. On the next reload the entire provider-specific config block +// could vanish. +// +// IsSet() (which checks both Val and EnvVar) is the correct check, and AfterFind +// reconstruction must consider all fields in the config, not just one. +// ============================================================================ + +// TestTableKey_VertexUnresolvedEnvVar_RoundTrip verifies that a Vertex key configured +// with an env var reference for ProjectID survives the BeforeSave/AfterFind round-trip +// even when the env var is NOT set on the server (so the resolved Val is empty). +func TestTableKey_VertexUnresolvedEnvVar_RoundTrip(t *testing.T) { + // Make sure the env var is NOT set so the resolved Val is empty. + require.NoError(t, os.Unsetenv("FAKE_VERTEX_PROJECT_ID_FOR_TEST")) + + db := setupTestDB(t) + + key := &TableKey{ + Name: "vertex-unresolved-env", + ProviderID: 1, + Provider: "vertex", + KeyID: "vertex-env-uuid-1", + Value: *schemas.NewEnvVar(""), + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_VERTEX_PROJECT_ID_FOR_TEST", + FromEnv: true, + }, + Region: *schemas.NewEnvVar("us-central1"), + }, + } + + require.NoError(t, db.Create(key).Error) + + // Read back through GORM (triggers AfterFind reconstruction). + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + // VertexKeyConfig must NOT be wiped β€” this was the original bug. + require.NotNil(t, found.VertexKeyConfig, "VertexKeyConfig was wiped on reload") + assert.Equal(t, "env.FAKE_VERTEX_PROJECT_ID_FOR_TEST", found.VertexKeyConfig.ProjectID.EnvVar, + "env var reference for ProjectID lost on round-trip") + assert.True(t, found.VertexKeyConfig.ProjectID.FromEnv, + "FromEnv flag for ProjectID lost on round-trip") + assert.Equal(t, "us-central1", found.VertexKeyConfig.Region.GetValue(), + "Plain Region value should survive round-trip unchanged") +} + +// TestTableKey_AzureUnresolvedEnvVar_RoundTrip verifies the same property for Azure. +// This also exercises the broadened AfterFind reconstruction condition: when only the +// endpoint is set (and unresolved), the entire AzureKeyConfig must still be reconstructed. +func TestTableKey_AzureUnresolvedEnvVar_RoundTrip(t *testing.T) { + require.NoError(t, os.Unsetenv("FAKE_AZURE_ENDPOINT_FOR_TEST")) + + db := setupTestDB(t) + + key := &TableKey{ + Name: "azure-unresolved-env", + ProviderID: 1, + Provider: "azure", + KeyID: "azure-env-uuid-1", + Value: *schemas.NewEnvVar(""), + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_AZURE_ENDPOINT_FOR_TEST", + FromEnv: true, + }, + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.AzureKeyConfig, "AzureKeyConfig was wiped on reload") + assert.Equal(t, "env.FAKE_AZURE_ENDPOINT_FOR_TEST", found.AzureKeyConfig.Endpoint.EnvVar, + "env var reference for Endpoint lost on round-trip") + assert.True(t, found.AzureKeyConfig.Endpoint.FromEnv, + "FromEnv flag for Endpoint lost on round-trip") +} + +// TestTableKey_AzureOnlyApiVersion_AfterFindReconstructs verifies that AzureKeyConfig +// is reconstructed from the DB even when ONLY a non-endpoint Azure field is set. +// Before the fix, AfterFind only checked AzureEndpoint != nil and would silently drop +// the entire Azure config when only api_version (or any other Azure field) was present. +func TestTableKey_AzureOnlyApiVersion_AfterFindReconstructs(t *testing.T) { + db := setupTestDB(t) + + apiVersion := schemas.NewEnvVar("2024-10-21") + key := &TableKey{ + Name: "azure-only-apiversion", + ProviderID: 1, + Provider: "azure", + KeyID: "azure-apiver-uuid-1", + Value: *schemas.NewEnvVar(""), + AzureKeyConfig: &schemas.AzureKeyConfig{ + // No endpoint, no client id β€” only api_version. + APIVersion: apiVersion, + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.AzureKeyConfig, + "AzureKeyConfig should be reconstructed when only api_version is present") + require.NotNil(t, found.AzureKeyConfig.APIVersion) + assert.Equal(t, "2024-10-21", found.AzureKeyConfig.APIVersion.GetValue()) +} + +// TestTableKey_BedrockUnresolvedEnvVar_RoundTrip verifies the same property for +// Bedrock explicit credentials. +func TestTableKey_BedrockUnresolvedEnvVar_RoundTrip(t *testing.T) { + require.NoError(t, os.Unsetenv("FAKE_AWS_ACCESS_KEY_FOR_TEST")) + require.NoError(t, os.Unsetenv("FAKE_AWS_SECRET_KEY_FOR_TEST")) + + db := setupTestDB(t) + + key := &TableKey{ + Name: "bedrock-unresolved-env", + ProviderID: 1, + Provider: "bedrock", + KeyID: "bedrock-env-uuid-1", + Value: *schemas.NewEnvVar(""), + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_AWS_ACCESS_KEY_FOR_TEST", + FromEnv: true, + }, + SecretKey: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_AWS_SECRET_KEY_FOR_TEST", + FromEnv: true, + }, + Region: schemas.NewEnvVar("us-west-2"), + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.BedrockKeyConfig, "BedrockKeyConfig was wiped on reload") + assert.Equal(t, "env.FAKE_AWS_ACCESS_KEY_FOR_TEST", found.BedrockKeyConfig.AccessKey.EnvVar, + "env var reference for AccessKey lost on round-trip") + assert.Equal(t, "env.FAKE_AWS_SECRET_KEY_FOR_TEST", found.BedrockKeyConfig.SecretKey.EnvVar, + "env var reference for SecretKey lost on round-trip") + require.NotNil(t, found.BedrockKeyConfig.Region) + assert.Equal(t, "us-west-2", found.BedrockKeyConfig.Region.GetValue()) +} + +// TestTableKey_OllamaUnresolvedEnvVar_RoundTrip and TestTableKey_SGLUnresolvedEnvVar_RoundTrip +// verify the same property for the recently-added providers, which also use env-aware persistence. +func TestTableKey_OllamaUnresolvedEnvVar_RoundTrip(t *testing.T) { + require.NoError(t, os.Unsetenv("FAKE_OLLAMA_URL_FOR_TEST")) + + db := setupTestDB(t) + + key := &TableKey{ + Name: "ollama-unresolved-env", + ProviderID: 1, + Provider: "ollama", + KeyID: "ollama-env-uuid-1", + Value: *schemas.NewEnvVar(""), + OllamaKeyConfig: &schemas.OllamaKeyConfig{ + URL: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_OLLAMA_URL_FOR_TEST", + FromEnv: true, + }, + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.OllamaKeyConfig, "OllamaKeyConfig was wiped on reload") + assert.Equal(t, "env.FAKE_OLLAMA_URL_FOR_TEST", found.OllamaKeyConfig.URL.EnvVar) + assert.True(t, found.OllamaKeyConfig.URL.FromEnv) +} + +func TestTableKey_SGLUnresolvedEnvVar_RoundTrip(t *testing.T) { + require.NoError(t, os.Unsetenv("FAKE_SGL_URL_FOR_TEST")) + + db := setupTestDB(t) + + key := &TableKey{ + Name: "sgl-unresolved-env", + ProviderID: 1, + Provider: "sgl", + KeyID: "sgl-env-uuid-1", + Value: *schemas.NewEnvVar(""), + SGLKeyConfig: &schemas.SGLKeyConfig{ + URL: schemas.EnvVar{ + Val: "", + EnvVar: "env.FAKE_SGL_URL_FOR_TEST", + FromEnv: true, + }, + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.SGLKeyConfig, "SGLKeyConfig was wiped on reload") + assert.Equal(t, "env.FAKE_SGL_URL_FOR_TEST", found.SGLKeyConfig.URL.EnvVar) + assert.True(t, found.SGLKeyConfig.URL.FromEnv) +} + +// TestTableKey_VertexPlainValue_RoundTrip is a sanity check ensuring that plain +// (non-env-backed) values still round-trip cleanly through the persistence layer +// after the IsSet() change. Both branches of the BeforeSave check matter. +func TestTableKey_VertexPlainValue_RoundTrip(t *testing.T) { + db := setupTestDB(t) + + key := &TableKey{ + Name: "vertex-plain", + ProviderID: 1, + Provider: "vertex", + KeyID: "vertex-plain-uuid-1", + Value: *schemas.NewEnvVar(""), + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("my-gcp-project"), + Region: *schemas.NewEnvVar("us-central1"), + }, + } + + require.NoError(t, db.Create(key).Error) + + var found TableKey + require.NoError(t, db.First(&found, key.ID).Error) + + require.NotNil(t, found.VertexKeyConfig) + assert.Equal(t, "my-gcp-project", found.VertexKeyConfig.ProjectID.GetValue()) + assert.False(t, found.VertexKeyConfig.ProjectID.FromEnv) + assert.Equal(t, "us-central1", found.VertexKeyConfig.Region.GetValue()) +} diff --git a/framework/configstore/tables/key.go b/framework/configstore/tables/key.go index c0763e9045..e790f73a00 100644 --- a/framework/configstore/tables/key.go +++ b/framework/configstore/tables/key.go @@ -29,21 +29,22 @@ type TableKey struct { // Config hash is used to detect changes synced from config.json file ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` + // Unified aliases + AliasesJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.KeyAliases + // Azure config fields (embedded instead of separate table for simplicity) - AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"` - AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"` - AzureDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"` - AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"` - AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"` - AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string + AzureEndpoint *schemas.EnvVar `gorm:"type:text" json:"azure_endpoint,omitempty"` + AzureAPIVersion *schemas.EnvVar `gorm:"type:text" json:"azure_api_version,omitempty"` + AzureClientID *schemas.EnvVar `gorm:"type:text" json:"azure_client_id,omitempty"` + AzureClientSecret *schemas.EnvVar `gorm:"type:text" json:"azure_client_secret,omitempty"` + AzureTenantID *schemas.EnvVar `gorm:"type:text" json:"azure_tenant_id,omitempty"` + AzureScopesJSON *string `gorm:"column:azure_scopes;type:text" json:"-"` // JSON serialized []string // Vertex config fields (embedded) VertexProjectID *schemas.EnvVar `gorm:"type:text" json:"vertex_project_id,omitempty"` VertexProjectNumber *schemas.EnvVar `gorm:"type:text" json:"vertex_project_number,omitempty"` VertexRegion *schemas.EnvVar `gorm:"type:text" json:"vertex_region,omitempty"` VertexAuthCredentials *schemas.EnvVar `gorm:"type:text" json:"vertex_auth_credentials,omitempty"` - VertexDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // Bedrock config fields (embedded) BedrockAccessKey *schemas.EnvVar `gorm:"type:text" json:"bedrock_access_key,omitempty"` @@ -54,16 +55,21 @@ type TableKey struct { BedrockRoleARN *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_arn,omitempty"` BedrockExternalID *schemas.EnvVar `gorm:"type:text" json:"bedrock_external_id,omitempty"` BedrockRoleSessionName *schemas.EnvVar `gorm:"type:text" json:"bedrock_role_session_name,omitempty"` - BedrockDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string BedrockBatchS3ConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.BatchS3Config - // Replicate config fields (embedded) - ReplicateDeploymentsJSON *string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - // VLLM config fields (embedded) VLLMUrl *schemas.EnvVar `gorm:"type:text" json:"vllm_url,omitempty"` VLLMModelName *string `gorm:"type:varchar(255)" json:"vllm_model_name,omitempty"` + // Replicate config fields (embedded) + ReplicateUseDeploymentsEndpoint *bool `gorm:"column:replicate_use_deployments_endpoint" json:"replicate_use_deployments_endpoint,omitempty"` + + // Ollama config fields (embedded) + OllamaUrl *schemas.EnvVar `gorm:"type:text" json:"ollama_url,omitempty"` + + // SGL config fields (embedded) + SGLUrl *schemas.EnvVar `gorm:"type:text" json:"sgl_url,omitempty"` + // Batch API configuration UseForBatchAPI *bool `gorm:"default:false" json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations @@ -73,13 +79,16 @@ type TableKey struct { EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` // Virtual fields for runtime use (not stored in DB) - Models []string `gorm:"-" json:"models"` - BlacklistedModels []string `gorm:"-" json:"blacklisted_models"` + Models schemas.WhiteList `gorm:"-" json:"models"` // ["*"] allows all models; empty denies all (deny-by-default) + BlacklistedModels schemas.BlackList `gorm:"-" json:"blacklisted_models"` + Aliases schemas.KeyAliases `gorm:"-" json:"aliases,omitempty"` AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"` BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` - ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"` VLLMKeyConfig *schemas.VLLMKeyConfig `gorm:"-" json:"vllm_key_config,omitempty"` + ReplicateKeyConfig *schemas.ReplicateKeyConfig `gorm:"-" json:"replicate_key_config,omitempty"` + OllamaKeyConfig *schemas.OllamaKeyConfig `gorm:"-" json:"ollama_key_config,omitempty"` + SGLKeyConfig *schemas.SGLKeyConfig `gorm:"-" json:"sgl_key_config,omitempty"` } // TableName sets the table name for each model @@ -91,24 +100,22 @@ func (TableKey) TableName() string { return "config_keys" } // batch S3 config) before writing to the database. Encryption runs last to ensure it // operates on the final serialized values. func (k *TableKey) BeforeSave(tx *gorm.DB) error { - if k.Models != nil { - data, err := json.Marshal(k.Models) - if err != nil { - return err - } - k.ModelsJSON = string(data) - } else { - k.ModelsJSON = "[]" + if err := k.Models.Validate(); err != nil { + return err } - if k.BlacklistedModels != nil { - data, err := json.Marshal(k.BlacklistedModels) - if err != nil { - return err - } - k.BlacklistedModelsJSON = string(data) - } else { - k.BlacklistedModelsJSON = "[]" + data, err := json.Marshal(k.Models) + if err != nil { + return err + } + k.ModelsJSON = string(data) + if err := k.BlacklistedModels.Validate(); err != nil { + return err } + data, err = json.Marshal(k.BlacklistedModels) + if err != nil { + return err + } + k.BlacklistedModelsJSON = string(data) if k.Enabled == nil { enabled := true // DB default k.Enabled = &enabled @@ -123,7 +130,7 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { // shared pointer, the caller's in-memory config is silently corrupted. // See: TestBeforeSave_DoesNotMutateSharedProviderConfigs if k.AzureKeyConfig != nil { - if k.AzureKeyConfig.Endpoint.GetValue() != "" { + if k.AzureKeyConfig.Endpoint.IsSet() { ep := k.AzureKeyConfig.Endpoint k.AzureEndpoint = &ep } else { @@ -163,76 +170,54 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.AzureScopesJSON = nil } - if k.AzureKeyConfig.Deployments != nil { - data, err := json.Marshal(k.AzureKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.AzureDeploymentsJSON = &s - } else { - k.AzureDeploymentsJSON = nil - } } else { k.AzureEndpoint = nil k.AzureAPIVersion = nil - k.AzureDeploymentsJSON = nil k.AzureClientID = nil k.AzureClientSecret = nil k.AzureTenantID = nil k.AzureScopesJSON = nil } if k.VertexKeyConfig != nil { - if k.VertexKeyConfig.ProjectID.GetValue() != "" { + if k.VertexKeyConfig.ProjectID.IsSet() { pid := k.VertexKeyConfig.ProjectID k.VertexProjectID = &pid } else { k.VertexProjectID = nil } - if k.VertexKeyConfig.ProjectNumber.GetValue() != "" { + if k.VertexKeyConfig.ProjectNumber.IsSet() { pn := k.VertexKeyConfig.ProjectNumber k.VertexProjectNumber = &pn } else { k.VertexProjectNumber = nil } - if k.VertexKeyConfig.Region.GetValue() != "" { + if k.VertexKeyConfig.Region.IsSet() { vr := k.VertexKeyConfig.Region k.VertexRegion = &vr } else { k.VertexRegion = nil } - if k.VertexKeyConfig.AuthCredentials.GetValue() != "" { + if k.VertexKeyConfig.AuthCredentials.IsSet() { ac := k.VertexKeyConfig.AuthCredentials k.VertexAuthCredentials = &ac } else { k.VertexAuthCredentials = nil } - if k.VertexKeyConfig.Deployments != nil { - data, err := json.Marshal(k.VertexKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.VertexDeploymentsJSON = &s - } else { - k.VertexDeploymentsJSON = nil - } } else { k.VertexProjectID = nil k.VertexProjectNumber = nil k.VertexRegion = nil k.VertexAuthCredentials = nil - k.VertexDeploymentsJSON = nil } if k.BedrockKeyConfig != nil { - if k.BedrockKeyConfig.AccessKey.GetValue() != "" { + if k.BedrockKeyConfig.AccessKey.IsSet() { // Copy to avoid encrypting the shared BedrockKeyConfig through the pointer ak := k.BedrockKeyConfig.AccessKey k.BedrockAccessKey = &ak } else { k.BedrockAccessKey = nil } - if k.BedrockKeyConfig.SecretKey.GetValue() != "" { + if k.BedrockKeyConfig.SecretKey.IsSet() { // Copy to avoid encrypting the shared BedrockKeyConfig through the pointer sk := k.BedrockKeyConfig.SecretKey k.BedrockSecretKey = &sk @@ -276,16 +261,6 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { } else { k.BedrockRoleSessionName = nil } - if k.BedrockKeyConfig.Deployments != nil { - data, err := sonic.Marshal(k.BedrockKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.BedrockDeploymentsJSON = &s - } else { - k.BedrockDeploymentsJSON = nil - } if k.BedrockKeyConfig.BatchS3Config != nil { data, err := sonic.Marshal(k.BedrockKeyConfig.BatchS3Config) if err != nil { @@ -305,27 +280,25 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.BedrockRoleARN = nil k.BedrockExternalID = nil k.BedrockRoleSessionName = nil - k.BedrockDeploymentsJSON = nil k.BedrockBatchS3ConfigJSON = nil } - if k.ReplicateKeyConfig != nil { - if k.ReplicateKeyConfig.Deployments != nil { - data, err := sonic.Marshal(k.ReplicateKeyConfig.Deployments) - if err != nil { - return err - } - s := string(data) - k.ReplicateDeploymentsJSON = &s - } else { - k.ReplicateDeploymentsJSON = nil + if k.Aliases != nil { + if err := k.Aliases.Validate(); err != nil { + return err + } + data, err := sonic.Marshal(k.Aliases) + if err != nil { + return err } + s := string(data) + k.AliasesJSON = &s } else { - k.ReplicateDeploymentsJSON = nil + k.AliasesJSON = nil } if k.VLLMKeyConfig != nil { - if k.VLLMKeyConfig.URL.GetValue() != "" { + if k.VLLMKeyConfig.URL.IsSet() { u := k.VLLMKeyConfig.URL // Value-copy to prevent shared pointer mutation k.VLLMUrl = &u } else { @@ -342,6 +315,27 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.VLLMModelName = nil } + if k.ReplicateKeyConfig != nil { + v := k.ReplicateKeyConfig.UseDeploymentsEndpoint + k.ReplicateUseDeploymentsEndpoint = &v + } else { + k.ReplicateUseDeploymentsEndpoint = nil + } + + if k.OllamaKeyConfig != nil && k.OllamaKeyConfig.URL.IsSet() { + u := k.OllamaKeyConfig.URL + k.OllamaUrl = &u + } else { + k.OllamaUrl = nil + } + + if k.SGLKeyConfig != nil && k.SGLKeyConfig.URL.IsSet() { + u := k.SGLKeyConfig.URL + k.SGLUrl = &u + } else { + k.SGLUrl = nil + } + // Encrypt sensitive fields after serialization if encrypt.IsEnabled() { if err := encryptEnvVar(&k.Value); err != nil { @@ -401,16 +395,25 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { if err := encryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil { return fmt.Errorf("failed to encrypt bedrock role session name: %w", err) } - if err := encryptString(k.BedrockDeploymentsJSON); err != nil { - return fmt.Errorf("failed to encrypt bedrock deployments: %w", err) - } if err := encryptString(k.BedrockBatchS3ConfigJSON); err != nil { return fmt.Errorf("failed to encrypt bedrock batch s3 config: %w", err) } + // Aliases + if err := encryptString(k.AliasesJSON); err != nil { + return fmt.Errorf("failed to encrypt aliases: %w", err) + } // VLLM if err := encryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to encrypt vllm url: %w", err) } + // Ollama + if err := encryptEnvVarPtr(&k.OllamaUrl); err != nil { + return fmt.Errorf("failed to encrypt ollama url: %w", err) + } + // SGL + if err := encryptEnvVarPtr(&k.SGLUrl); err != nil { + return fmt.Errorf("failed to encrypt sgl url: %w", err) + } k.EncryptionStatus = EncryptionStatusEncrypted } return nil @@ -479,31 +482,36 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { if err := decryptEnvVarPtr(&k.BedrockRoleSessionName); err != nil { return fmt.Errorf("failed to decrypt bedrock role session name: %w", err) } - if err := decryptString(k.BedrockDeploymentsJSON); err != nil { - return fmt.Errorf("failed to decrypt bedrock deployments: %w", err) - } if err := decryptString(k.BedrockBatchS3ConfigJSON); err != nil { return fmt.Errorf("failed to decrypt bedrock batch s3 config: %w", err) } + // Aliases + if err := decryptString(k.AliasesJSON); err != nil { + return fmt.Errorf("failed to decrypt aliases: %w", err) + } // VLLM if err := decryptEnvVarPtr(&k.VLLMUrl); err != nil { return fmt.Errorf("failed to decrypt vllm url: %w", err) } + // Ollama + if err := decryptEnvVarPtr(&k.OllamaUrl); err != nil { + return fmt.Errorf("failed to decrypt ollama url: %w", err) + } + // SGL + if err := decryptEnvVarPtr(&k.SGLUrl); err != nil { + return fmt.Errorf("failed to decrypt sgl url: %w", err) + } } if k.ModelsJSON != "" { if err := json.Unmarshal([]byte(k.ModelsJSON), &k.Models); err != nil { return err } - } else { - k.Models = []string{} } if k.BlacklistedModelsJSON != "" { if err := json.Unmarshal([]byte(k.BlacklistedModelsJSON), &k.BlacklistedModels); err != nil { return err } - } else { - k.BlacklistedModels = []string{} } if k.Enabled == nil { enabled := true // DB default @@ -514,7 +522,7 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { k.UseForBatchAPI = &useForBatchAPI } // Reconstruct Azure config if fields are present - if k.AzureEndpoint != nil { + if k.AzureEndpoint != nil || k.AzureAPIVersion != nil || k.AzureClientID != nil || k.AzureClientSecret != nil || k.AzureTenantID != nil || (k.AzureScopesJSON != nil && *k.AzureScopesJSON != "") { var scopes []string if k.AzureScopesJSON != nil && *k.AzureScopesJSON != "" { if err := json.Unmarshal([]byte(*k.AzureScopesJSON), &scopes); err != nil { @@ -534,20 +542,10 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { azureConfig.Endpoint = *k.AzureEndpoint } - if k.AzureDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.AzureDeploymentsJSON), &deployments); err != nil { - return err - } - azureConfig.Deployments = deployments - } else { - azureConfig.Deployments = nil - } - k.AzureKeyConfig = azureConfig } // Reconstruct Vertex config if fields are present - if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil || (k.VertexDeploymentsJSON != nil && *k.VertexDeploymentsJSON != "") { + if k.VertexProjectID != nil || k.VertexProjectNumber != nil || k.VertexRegion != nil || k.VertexAuthCredentials != nil { config := &schemas.VertexKeyConfig{} if k.VertexProjectID != nil { @@ -564,20 +562,10 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { if k.VertexAuthCredentials != nil { config.AuthCredentials = *k.VertexAuthCredentials } - if k.VertexDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.VertexDeploymentsJSON), &deployments); err != nil { - return err - } - config.Deployments = deployments - } else { - config.Deployments = nil - } - k.VertexKeyConfig = config } // Reconstruct Bedrock config if fields are present - if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockDeploymentsJSON != nil && *k.BedrockDeploymentsJSON != "") || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") { + if k.BedrockAccessKey != nil || k.BedrockSecretKey != nil || k.BedrockSessionToken != nil || k.BedrockRegion != nil || k.BedrockARN != nil || k.BedrockRoleARN != nil || k.BedrockExternalID != nil || k.BedrockRoleSessionName != nil || (k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "") { bedrockConfig := &schemas.BedrockKeyConfig{} if k.BedrockAccessKey != nil { @@ -595,16 +583,6 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { bedrockConfig.SecretKey = *k.BedrockSecretKey } - if k.BedrockDeploymentsJSON != nil { - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.BedrockDeploymentsJSON), &deployments); err != nil { - return err - } - bedrockConfig.Deployments = deployments - } else { - bedrockConfig.Deployments = nil - } - if k.BedrockBatchS3ConfigJSON != nil && *k.BedrockBatchS3ConfigJSON != "" { var batchS3Config schemas.BatchS3Config if err := json.Unmarshal([]byte(*k.BedrockBatchS3ConfigJSON), &batchS3Config); err != nil { @@ -615,15 +593,15 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { k.BedrockKeyConfig = bedrockConfig } - // Reconstruct Replicate config if fields are present - if k.ReplicateDeploymentsJSON != nil && *k.ReplicateDeploymentsJSON != "" { - replicateConfig := &schemas.ReplicateKeyConfig{} - var deployments map[string]string - if err := json.Unmarshal([]byte(*k.ReplicateDeploymentsJSON), &deployments); err != nil { + // Reconstruct Aliases + if k.AliasesJSON != nil && *k.AliasesJSON != "" { + var aliases schemas.KeyAliases + if err := sonic.Unmarshal([]byte(*k.AliasesJSON), &aliases); err != nil { return err } - replicateConfig.Deployments = deployments - k.ReplicateKeyConfig = replicateConfig + k.Aliases = aliases + } else { + k.Aliases = nil } // Reconstruct VLLM config if fields are present if k.VLLMUrl != nil || (k.VLLMModelName != nil && *k.VLLMModelName != "") { @@ -638,5 +616,29 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { } else { k.VLLMKeyConfig = nil } + // Reconstruct Replicate config if fields are present + if k.ReplicateUseDeploymentsEndpoint != nil { + k.ReplicateKeyConfig = &schemas.ReplicateKeyConfig{ + UseDeploymentsEndpoint: *k.ReplicateUseDeploymentsEndpoint, + } + } else { + k.ReplicateKeyConfig = nil + } + // Reconstruct Ollama config if fields are present + if k.OllamaUrl != nil { + k.OllamaKeyConfig = &schemas.OllamaKeyConfig{ + URL: *k.OllamaUrl, + } + } else { + k.OllamaKeyConfig = nil + } + // Reconstruct SGL config if fields are present + if k.SGLUrl != nil { + k.SGLKeyConfig = &schemas.SGLKeyConfig{ + URL: *k.SGLUrl, + } + } else { + k.SGLKeyConfig = nil + } return nil } diff --git a/framework/configstore/tables/mcp.go b/framework/configstore/tables/mcp.go index 38a5a6b942..6e844a073c 100644 --- a/framework/configstore/tables/mcp.go +++ b/framework/configstore/tables/mcp.go @@ -13,25 +13,32 @@ import ( // TableMCPClient represents an MCP client configuration in the database type TableMCPClient struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. - ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` - Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` - IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType - ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"` - StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig - ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string - HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string - IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks - ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 - ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled) + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` // ID is used as the internal primary key and is also accessed by public methods, so it must be present. + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + Name string `gorm:"type:varchar(255);uniqueIndex;not null" json:"name"` + IsCodeModeClient bool `gorm:"default:false" json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType string `gorm:"type:varchar(20);not null" json:"connection_type"` // schemas.MCPConnectionType + ConnectionString *schemas.EnvVar `gorm:"type:text" json:"connection_string,omitempty"` + StdioConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized schemas.MCPStdioConfig + ToolsToExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + ToolsToAutoExecuteJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + HeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string + AllowedExtraHeadersJSON string `gorm:"type:text" json:"-"` // JSON serialized []string + IsPingAvailable *bool `gorm:"default:true" json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks + ToolPricingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]float64 + ToolSyncInterval int `gorm:"default:0" json:"tool_sync_interval"` // Per-client tool sync interval in minutes (0 = use global, -1 = disabled) + + // Per-user OAuth: discovered tools persisted so they survive restart + DiscoveredToolsJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]schemas.ChatTool + ToolNameMappingJSON string `gorm:"type:text" json:"-"` // JSON serialized map[string]string // OAuth authentication fields AuthType string `gorm:"type:varchar(20);default:'headers'" json:"auth_type"` // "none", "headers", "oauth" OauthConfigID *string `gorm:"type:varchar(255);index;constraint:OnDelete:CASCADE" json:"oauth_config_id"` // Foreign key to oauth_configs.ID with CASCADE delete OauthConfig *TableOauthConfig `gorm:"foreignKey:OauthConfigID;references:ID;constraint:OnDelete:CASCADE" json:"-"` // Gorm relationship + AllowOnAllVirtualKeys bool `gorm:"default:false" json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys + // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash"` @@ -42,11 +49,14 @@ type TableMCPClient struct { UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` // Virtual fields for runtime use (not stored in DB) - StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` - ToolsToExecute []string `gorm:"-" json:"tools_to_execute"` - ToolsToAutoExecute []string `gorm:"-" json:"tools_to_auto_execute"` - Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"` - ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"` + StdioConfig *schemas.MCPStdioConfig `gorm:"-" json:"stdio_config,omitempty"` + ToolsToExecute schemas.WhiteList `gorm:"-" json:"tools_to_execute"` + ToolsToAutoExecute schemas.WhiteList `gorm:"-" json:"tools_to_auto_execute"` + Headers map[string]schemas.EnvVar `gorm:"-" json:"headers"` + AllowedExtraHeaders schemas.WhiteList `gorm:"-" json:"allowed_extra_headers"` + ToolPricing map[string]float64 `gorm:"-" json:"tool_pricing"` + DiscoveredTools map[string]schemas.ChatTool `gorm:"-" json:"-"` + DiscoveredToolNameMapping map[string]string `gorm:"-" json:"-"` } // TableName sets the table name for each model @@ -68,6 +78,9 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } if c.ToolsToExecute != nil { + if err := c.ToolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) + } data, err := json.Marshal(c.ToolsToExecute) if err != nil { return err @@ -78,6 +91,9 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { } if c.ToolsToAutoExecute != nil { + if err := c.ToolsToAutoExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_auto_execute: %w", err) + } data, err := json.Marshal(c.ToolsToAutoExecute) if err != nil { return err @@ -105,6 +121,19 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.HeadersJSON = "{}" } + if c.AllowedExtraHeaders != nil { + if err := c.AllowedExtraHeaders.Validate(); err != nil { + return fmt.Errorf("invalid allowed_extra_headers: %w", err) + } + data, err := json.Marshal(c.AllowedExtraHeaders) + if err != nil { + return err + } + c.AllowedExtraHeadersJSON = string(data) + } else { + c.AllowedExtraHeadersJSON = "[]" + } + if c.ToolPricing != nil { data, err := json.Marshal(c.ToolPricing) if err != nil { @@ -115,6 +144,22 @@ func (c *TableMCPClient) BeforeSave(tx *gorm.DB) error { c.ToolPricingJSON = "{}" } + if c.DiscoveredTools != nil { + data, err := json.Marshal(c.DiscoveredTools) + if err != nil { + return err + } + c.DiscoveredToolsJSON = string(data) + } + + if c.DiscoveredToolNameMapping != nil { + data, err := json.Marshal(c.DiscoveredToolNameMapping) + if err != nil { + return err + } + c.ToolNameMappingJSON = string(data) + } + // Encrypt sensitive fields after serialization. // Always set EncryptionStatus when encryption is enabled so the startup // batch pass does not re-process this row indefinitely. @@ -183,11 +228,25 @@ func (c *TableMCPClient) AfterFind(tx *gorm.DB) error { return err } } - + if c.AllowedExtraHeadersJSON != "" { + if err := sonic.Unmarshal([]byte(c.AllowedExtraHeadersJSON), &c.AllowedExtraHeaders); err != nil { + return err + } + } if c.ToolPricingJSON != "" { if err := json.Unmarshal([]byte(c.ToolPricingJSON), &c.ToolPricing); err != nil { return err } } + if c.DiscoveredToolsJSON != "" { + if err := sonic.Unmarshal([]byte(c.DiscoveredToolsJSON), &c.DiscoveredTools); err != nil { + return err + } + } + if c.ToolNameMappingJSON != "" { + if err := sonic.Unmarshal([]byte(c.ToolNameMappingJSON), &c.DiscoveredToolNameMapping); err != nil { + return err + } + } return nil } diff --git a/framework/configstore/tables/modelpricing.go b/framework/configstore/tables/modelpricing.go index 9cff0bd6cb..8117fd58a1 100644 --- a/framework/configstore/tables/modelpricing.go +++ b/framework/configstore/tables/modelpricing.go @@ -15,8 +15,8 @@ type TableModelPricing struct { Architecture *schemas.Architecture `gorm:"type:text;serializer:json;default:null" json:"architecture,omitempty"` // Costs - Text - InputCostPerToken float64 `gorm:"not null" json:"input_cost_per_token"` - OutputCostPerToken float64 `gorm:"not null" json:"output_cost_per_token"` + InputCostPerToken *float64 `gorm:"default:null" json:"input_cost_per_token,omitempty"` + OutputCostPerToken *float64 `gorm:"default:null" json:"output_cost_per_token,omitempty"` InputCostPerTokenBatches *float64 `gorm:"default:null;column:input_cost_per_token_batches" json:"input_cost_per_token_batches,omitempty"` OutputCostPerTokenBatches *float64 `gorm:"default:null;column:output_cost_per_token_batches" json:"output_cost_per_token_batches,omitempty"` InputCostPerTokenPriority *float64 `gorm:"default:null;column:input_cost_per_token_priority" json:"input_cost_per_token_priority,omitempty"` diff --git a/framework/configstore/tables/oauth.go b/framework/configstore/tables/oauth.go index 81bb1a03f5..9cb65bc4af 100644 --- a/framework/configstore/tables/oauth.go +++ b/framework/configstore/tables/oauth.go @@ -11,26 +11,26 @@ import ( // TableOauthConfig represents an OAuth configuration in the database // This stores the OAuth client configuration and flow state type TableOauthConfig struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID - ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients) - ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients) - AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered) - TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered) - RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered) - RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL - Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered) - State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token - CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret) - CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) - Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" - TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) - ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery - UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery - MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) - EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID + ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients) + ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients) + AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered) + TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered) + RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered) + RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered) + State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token + CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret) + CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) + Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" + TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) + ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery + UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery + MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) } // TableName sets the table name @@ -83,13 +83,13 @@ func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error { // TableOauthToken represents an OAuth token in the database // This stores the actual access and refresh tokens type TableOauthToken struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID - AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token - RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional) - TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration - Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes - LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID + AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token + RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional) + TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes + LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` @@ -132,3 +132,248 @@ func (t *TableOauthToken) AfterFind(tx *gorm.DB) error { } return nil } + +// ---------- Per-User OAuth Tables ---------- + +// TableOauthUserSession tracks pending per-user OAuth flows. +// Each record maps an OAuth state token to a specific MCP client, allowing +// the callback to associate the resulting tokens with the correct user session. +type TableOauthUserSession struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID + MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for + OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.) + State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider + RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step + CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret) + SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions) + SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups + GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken) + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens) + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens) + Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired" + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min) + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +func (TableOauthUserSession) TableName() string { + return "oauth_user_sessions" +} + +func (s *TableOauthUserSession) BeforeSave(tx *gorm.DB) error { + if s.Status == "" { + s.Status = "pending" + } + if s.SessionToken != "" { + s.SessionTokenHash = encrypt.HashSHA256(s.SessionToken) + } + if encrypt.IsEnabled() { + if s.CodeVerifier != "" { + if err := encryptString(&s.CodeVerifier); err != nil { + return fmt.Errorf("failed to encrypt oauth user session code verifier: %w", err) + } + } + s.EncryptionStatus = EncryptionStatusEncrypted + } + return nil +} + +func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error { + if s.EncryptionStatus == EncryptionStatusEncrypted && s.CodeVerifier != "" { + if err := decryptString(&s.CodeVerifier); err != nil { + return fmt.Errorf("failed to decrypt oauth user session code verifier: %w", err) + } + } + return nil +} + +// TableOauthUserToken stores per-user OAuth credentials. +// Each record holds the access/refresh tokens for a specific user session + MCP client pair. +// Lookup is by SessionToken. +type TableOauthUserToken struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID + SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users) + SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions) + UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions) + MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"` // Which MCP server + OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config + AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token + RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token + TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes + LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +func (TableOauthUserToken) TableName() string { + return "oauth_user_tokens" +} + +func (t *TableOauthUserToken) BeforeSave(tx *gorm.DB) error { + if t.TokenType == "" { + t.TokenType = "Bearer" + } + if t.SessionToken != "" { + t.SessionTokenHash = encrypt.HashSHA256(t.SessionToken) + } + if encrypt.IsEnabled() { + if err := encryptString(&t.AccessToken); err != nil { + return fmt.Errorf("failed to encrypt oauth user access token: %w", err) + } + if err := encryptString(&t.RefreshToken); err != nil { + return fmt.Errorf("failed to encrypt oauth user refresh token: %w", err) + } + t.EncryptionStatus = EncryptionStatusEncrypted + } + return nil +} + +func (t *TableOauthUserToken) AfterFind(tx *gorm.DB) error { + if t.EncryptionStatus == EncryptionStatusEncrypted { + if err := decryptString(&t.AccessToken); err != nil { + return fmt.Errorf("failed to decrypt oauth user access token: %w", err) + } + if err := decryptString(&t.RefreshToken); err != nil { + return fmt.Errorf("failed to decrypt oauth user refresh token: %w", err) + } + } + return nil +} + +// ---------- Per-User OAuth Authorization Server Tables ---------- + +// TablePerUserOAuthClient stores dynamically registered OAuth clients (RFC 7591). +// MCP clients (like Claude Code) register themselves with Bifrost's OAuth +// authorization server to obtain a client_id for the authorization code flow. +type TablePerUserOAuthClient struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + ClientID string `gorm:"type:varchar(255);uniqueIndex;not null" json:"client_id"` + ClientName string `gorm:"type:varchar(255)" json:"client_name"` + RedirectURIs string `gorm:"type:text;not null" json:"redirect_uris"` // JSON array of allowed redirect URIs + GrantTypes string `gorm:"type:text" json:"grant_types"` // JSON array of grant types + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableName returns the table name for per-user OAuth clients. +func (TablePerUserOAuthClient) TableName() string { + return "oauth_per_user_clients" +} + +// TablePerUserOAuthSession stores Bifrost-issued access tokens for authenticated +// MCP connections. When a user authenticates via Bifrost's OAuth flow, a session +// is created. The access token is included in all subsequent MCP requests. +// Upstream provider tokens are linked via the oauth_user_tokens table. +type TablePerUserOAuthSession struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted) + AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups + RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional) + RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique β€” refresh tokens are optional) + ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth) + VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized) + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present) + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableName returns the table name for per-user OAuth sessions. +func (TablePerUserOAuthSession) TableName() string { + return "oauth_per_user_sessions" +} + +// BeforeSave encrypts sensitive fields. +func (s *TablePerUserOAuthSession) BeforeSave(tx *gorm.DB) error { + if s.AccessToken != "" { + s.AccessTokenHash = encrypt.HashSHA256(s.AccessToken) + } + if s.RefreshToken != "" { + s.RefreshTokenHash = encrypt.HashSHA256(s.RefreshToken) + } + if encrypt.IsEnabled() { + if err := encryptString(&s.AccessToken); err != nil { + return fmt.Errorf("failed to encrypt per-user oauth access token: %w", err) + } + if s.RefreshToken != "" { + if err := encryptString(&s.RefreshToken); err != nil { + return fmt.Errorf("failed to encrypt per-user oauth refresh token: %w", err) + } + } + s.EncryptionStatus = EncryptionStatusEncrypted + } + return nil +} + +// AfterFind decrypts sensitive fields. +func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error { + if s.EncryptionStatus == EncryptionStatusEncrypted { + if err := decryptString(&s.AccessToken); err != nil { + return fmt.Errorf("failed to decrypt per-user oauth access token: %w", err) + } + if s.RefreshToken != "" { + if err := decryptString(&s.RefreshToken); err != nil { + return fmt.Errorf("failed to decrypt per-user oauth refresh token: %w", err) + } + } + } + return nil +} + +// TablePerUserOAuthCode stores authorization codes during the OAuth flow. +// Codes are short-lived (5 minutes) and single-use. +type TablePerUserOAuthCode struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + Code string `gorm:"type:text;not null" json:"-"` // Authorization code + CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups + ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` + RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` + CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of requested scopes + SessionID string `gorm:"type:varchar(255);index" json:"-"` // Links to the TablePerUserOAuthSession created during consent submit + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 5 min TTL + Used bool `gorm:"default:false;not null" json:"used"` // Single-use flag + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` +} + +// BeforeSave hashes the code for secure lookups. +func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error { + if c.Code != "" { + c.CodeHash = encrypt.HashSHA256(c.Code) + } + return nil +} + +// TableName returns the table name for per-user OAuth authorization codes. +func (TablePerUserOAuthCode) TableName() string { + return "oauth_per_user_codes" +} + +// TablePerUserOAuthPendingFlow stores OAuth parameters between the authorize step +// and the final code issuance. It carries state through the multi-step consent +// screen (VK entry + per-MCP upstream auth) before a real authorization code is issued. +type TablePerUserOAuthPendingFlow struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Registered OAuth client (from authorize request) + RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Client's callback URL + CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge (echoed into the final code) + State string `gorm:"type:text;not null" json:"-"` // Original OAuth state (echoed back on final redirect) + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Set if user chose VK identity + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Set if user chose User ID identity + BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"` // SHA-256 hash of browser-binding cookie secret + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 15-min TTL + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableName returns the table name for per-user OAuth pending flows. +func (TablePerUserOAuthPendingFlow) TableName() string { + return "oauth_per_user_pending_flows" +} diff --git a/framework/configstore/tables/pricingoverride.go b/framework/configstore/tables/pricingoverride.go new file mode 100644 index 0000000000..e4b23e3069 --- /dev/null +++ b/framework/configstore/tables/pricingoverride.go @@ -0,0 +1,55 @@ +package tables + +import ( + "encoding/json" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "gorm.io/gorm" +) + +// TablePricingOverride is the persistence model for governance pricing overrides. +type TablePricingOverride struct { + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"type:varchar(255);not null" json:"name"` + ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"` + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"` + ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"` + ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"` + MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"` + Pattern string `gorm:"type:varchar(255);not null" json:"pattern"` + RequestTypesJSON string `gorm:"type:text" json:"-"` + PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"` + ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"` +} + +// TableName returns the backing table name for governance pricing overrides. +func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" } + +// BeforeSave serializes virtual fields into their JSON columns before persistence. +func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error { + if len(p.RequestTypes) > 0 { + b, err := json.Marshal(p.RequestTypes) + if err != nil { + return err + } + p.RequestTypesJSON = string(b) + } else { + p.RequestTypesJSON = "[]" + } + return nil +} + +// AfterFind restores virtual fields from their persisted JSON columns. +func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error { + if p.RequestTypesJSON != "" { + if err := json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes); err != nil { + return err + } + } + return nil +} diff --git a/framework/configstore/tables/provider.go b/framework/configstore/tables/provider.go index 19a54f0bb0..2b33925d82 100644 --- a/framework/configstore/tables/provider.go +++ b/framework/configstore/tables/provider.go @@ -22,7 +22,6 @@ type TableProvider struct { ProxyConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.ProxyConfig CustomProviderConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.CustomProviderConfig OpenAIConfigJSON string `gorm:"type:text" json:"-"` // JSON serialized schemas.OpenAIConfig - PricingOverridesJSON string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ProviderPricingOverride SendBackRawRequest bool `json:"send_back_raw_request"` SendBackRawResponse bool `json:"send_back_raw_response"` StoreRawRequestResponse bool `json:"store_raw_request_response"` @@ -38,9 +37,8 @@ type TableProvider struct { ProxyConfig *schemas.ProxyConfig `gorm:"-" json:"proxy_config,omitempty"` // Custom provider fields - CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"` - OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"` - PricingOverrides []schemas.ProviderPricingOverride `gorm:"-" json:"pricing_overrides,omitempty"` + CustomProviderConfig *schemas.CustomProviderConfig `gorm:"-" json:"custom_provider_config,omitempty"` + OpenAIConfig *schemas.OpenAIConfig `gorm:"-" json:"openai_config,omitempty"` // Foreign keys Models []TableModel `gorm:"foreignKey:ProviderID;constraint:OnDelete:CASCADE" json:"models"` @@ -111,16 +109,6 @@ func (p *TableProvider) BeforeSave(tx *gorm.DB) error { } else { p.OpenAIConfigJSON = "" } - if p.PricingOverrides != nil { - data, err := json.Marshal(p.PricingOverrides) - if err != nil { - return err - } - p.PricingOverridesJSON = string(data) - } else { - p.PricingOverridesJSON = "" - } - // Validate governance fields if p.BudgetID != nil && strings.TrimSpace(*p.BudgetID) == "" { return fmt.Errorf("budget_id cannot be an empty string") @@ -192,13 +180,5 @@ func (p *TableProvider) AfterFind(tx *gorm.DB) error { p.OpenAIConfig = &openaiConfig } - if p.PricingOverridesJSON != "" { - var overrides []schemas.ProviderPricingOverride - if err := json.Unmarshal([]byte(p.PricingOverridesJSON), &overrides); err != nil { - return err - } - p.PricingOverrides = overrides - } - return nil } diff --git a/framework/configstore/tables/routing_rules.go b/framework/configstore/tables/routing_rules.go index 05dca8a925..2ab826be73 100644 --- a/framework/configstore/tables/routing_rules.go +++ b/framework/configstore/tables/routing_rules.go @@ -19,7 +19,7 @@ type TableRoutingRule struct { CelExpression string `gorm:"type:text;not null" json:"cel_expression"` // Routing Targets (output) β€” 1:many relationship; weights must sum to 1 - Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets,omitempty"` + Targets []TableRoutingTarget `gorm:"foreignKey:RuleID;constraint:OnDelete:CASCADE" json:"targets"` Fallbacks *string `gorm:"type:text" json:"-"` // JSON array of fallback chains ParsedFallbacks []string `gorm:"-" json:"fallbacks,omitempty"` // Parsed fallbacks from JSON @@ -31,6 +31,9 @@ type TableRoutingRule struct { Scope string `gorm:"type:varchar(50);not null;uniqueIndex:idx_routing_rule_scope_name" json:"scope"` // "global" | "team" | "customer" | "virtual_key" ScopeID *string `gorm:"type:varchar(255);uniqueIndex:idx_routing_rule_scope_name" json:"scope_id"` // nil for global, otherwise entity ID + // Chaining + ChainRule bool `gorm:"not null;default:false" json:"chain_rule"` // If true, re-evaluates routing chain after this rule matches + // Execution Priority int `gorm:"type:int;not null;default:0;index" json:"priority"` // Lower = evaluated first within scope diff --git a/framework/configstore/tables/virtualkey.go b/framework/configstore/tables/virtualkey.go index 79594a0adc..fb603202eb 100644 --- a/framework/configstore/tables/virtualkey.go +++ b/framework/configstore/tables/virtualkey.go @@ -24,17 +24,17 @@ func (TableVirtualKeyProviderConfigKey) TableName() string { // TableVirtualKeyProviderConfig represents a provider configuration for a virtual key type TableVirtualKeyProviderConfig struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` - Provider string `gorm:"type:varchar(50);not null" json:"provider"` - Weight *float64 `json:"weight"` - AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed - BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` - RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"` + Provider string `gorm:"type:varchar(50);not null" json:"provider"` + Weight *float64 `json:"weight"` + AllowedModels schemas.WhiteList `gorm:"type:text;serializer:json" json:"allowed_models"` // ["*"] allows all models; empty denies all (deny-by-default) + AllowAllKeys bool `gorm:"default:false" json:"allow_all_keys"` // True means all keys allowed; false with empty Keys means no keys allowed (deny-by-default) + RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` // Relationships - Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"` RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"` + Budgets []TableBudget `gorm:"foreignKey:ProviderConfigID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals Keys []TableKey `gorm:"many2many:governance_virtual_key_provider_config_keys;constraint:OnDelete:CASCADE" json:"keys"` // Empty means all keys allowed for this provider } @@ -43,13 +43,12 @@ func (TableVirtualKeyProviderConfig) TableName() string { return "governance_virtual_key_provider_configs" } -// UnmarshalJSON custom unmarshaller to handle both "keys" ([]TableKey) and "allowed_keys" ([]string) formats +// UnmarshalJSON custom unmarshaller to handle "key_ids" ([]string) config-file format func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error { - // Temporary struct to capture all fields including allowed_keys type Alias TableVirtualKeyProviderConfig type TempProviderConfig struct { Alias - AllowedKeys []string `json:"allowed_keys"` // Config file format: array of key names + KeyIDs []string `json:"key_ids"` // Config file format: key identifiers (TableKey.KeyID); use ["*"] to allow all keys, empty denies all } var temp TempProviderConfig @@ -60,18 +59,32 @@ func (pc *TableVirtualKeyProviderConfig) UnmarshalJSON(data []byte) error { // Copy all standard fields *pc = TableVirtualKeyProviderConfig(temp.Alias) - // If allowed_keys is provided (config file format), convert to Keys - // This takes precedence if Keys is empty but allowed_keys has values - if len(temp.AllowedKeys) > 0 && len(pc.Keys) == 0 { - pc.Keys = make([]TableKey, len(temp.AllowedKeys)) - for i, keyName := range temp.AllowedKeys { - pc.Keys[i] = TableKey{Name: keyName} + // If key_ids is provided, convert to Keys or set AllowAllKeys + if len(temp.KeyIDs) > 0 && len(pc.Keys) == 0 { + // ["*"] means allow all keys + if len(temp.KeyIDs) == 1 && temp.KeyIDs[0] == "*" { + pc.AllowAllKeys = true + pc.Keys = nil + } else { + pc.AllowAllKeys = false + pc.Keys = make([]TableKey, len(temp.KeyIDs)) + for i, keyID := range temp.KeyIDs { + pc.Keys[i] = TableKey{KeyID: keyID} + } } } return nil } +// BeforeSave validates WhiteList fields before GORM persists the record. +func (pc *TableVirtualKeyProviderConfig) BeforeSave(tx *gorm.DB) error { + if err := pc.AllowedModels.Validate(); err != nil { + return fmt.Errorf("invalid allowed_models: %w", err) + } + return nil +} + // MarshalJSON custom marshaller to ensure AllowedModels is always an array (never null) func (pc TableVirtualKeyProviderConfig) MarshalJSON() ([]byte, error) { type Alias TableVirtualKeyProviderConfig @@ -108,7 +121,6 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { key.AzureClientSecret = nil key.AzureTenantID = nil key.AzureScopesJSON = nil - key.AzureDeploymentsJSON = nil key.AzureKeyConfig = nil // Clear all Vertex-related sensitive fields @@ -127,13 +139,8 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { key.BedrockRoleARN = nil key.BedrockExternalID = nil key.BedrockRoleSessionName = nil - key.BedrockDeploymentsJSON = nil key.BedrockKeyConfig = nil - // Clear all Replicate-related sensitive fields - key.ReplicateDeploymentsJSON = nil - key.ReplicateKeyConfig = nil - pc.Keys[i] = *key } } @@ -141,11 +148,11 @@ func (pc *TableVirtualKeyProviderConfig) AfterFind(tx *gorm.DB) error { } type TableVirtualKeyMCPConfig struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"` - MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"` - MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"` - ToolsToExecute []string `gorm:"type:text;serializer:json" json:"tools_to_execute"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + VirtualKeyID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_vk_mcpclient" json:"virtual_key_id"` + MCPClientID uint `gorm:"not null;uniqueIndex:idx_vk_mcpclient" json:"mcp_client_id"` + MCPClient TableMCPClient `gorm:"foreignKey:MCPClientID" json:"mcp_client"` + ToolsToExecute schemas.WhiteList `gorm:"type:text;serializer:json" json:"tools_to_execute"` // MCPClientName is used during config file parsing to resolve the MCP client by name. // This field is not persisted to the database - it's only used to capture @@ -158,6 +165,14 @@ func (TableVirtualKeyMCPConfig) TableName() string { return "governance_virtual_key_mcp_configs" } +// BeforeSave validates WhiteList fields before GORM persists the record. +func (mc *TableVirtualKeyMCPConfig) BeforeSave(tx *gorm.DB) error { + if err := mc.ToolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) + } + return nil +} + // UnmarshalJSON custom unmarshaller to handle both "mcp_client_id" (database format) // and "mcp_client_name" (config file format) for MCP client references. func (mc *TableVirtualKeyMCPConfig) UnmarshalJSON(data []byte) error { @@ -191,20 +206,20 @@ type TableVirtualKey struct { Description string `gorm:"type:text" json:"description,omitempty"` Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:text;not null" json:"value"` // The virtual key value IsActive bool `gorm:"default:true" json:"is_active"` - ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed + ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means no providers allowed (deny-by-default) MCPConfigs []TableVirtualKeyMCPConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"mcp_configs"` // Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both) - TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` - CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` - BudgetID *string `gorm:"type:varchar(255);index" json:"budget_id,omitempty"` - RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"` + CustomerID *string `gorm:"type:varchar(255);index" json:"customer_id,omitempty"` + RateLimitID *string `gorm:"type:varchar(255);index" json:"rate_limit_id,omitempty"` + CalendarAligned bool `gorm:"default:false" json:"calendar_aligned"` // When true, all budgets under this VK reset at clean calendar boundaries // Relationships Team *TableTeam `gorm:"foreignKey:TeamID" json:"team,omitempty"` Customer *TableCustomer `gorm:"foreignKey:CustomerID" json:"customer,omitempty"` - Budget *TableBudget `gorm:"foreignKey:BudgetID;onDelete:CASCADE" json:"budget,omitempty"` RateLimit *TableRateLimit `gorm:"foreignKey:RateLimitID;onDelete:CASCADE" json:"rate_limit,omitempty"` + Budgets []TableBudget `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"budgets,omitempty"` // Multiple budgets with different reset intervals // Config hash is used to detect the changes synced from config.json file // Every time we sync the config.json file, we will update the config hash diff --git a/framework/go.mod b/framework/go.mod index c467d6f32b..72dcf48ead 100644 --- a/framework/go.mod +++ b/framework/go.mod @@ -4,7 +4,7 @@ go 1.26.1 require ( github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.4.17 + github.com/maximhq/bifrost/core v1.5.1 github.com/pinecone-io/go-pinecone/v5 v5.3.0 github.com/qdrant/go-client v1.16.2 github.com/redis/go-redis/v9 v9.17.2 diff --git a/framework/go.sum b/framework/go.sum index 8fe58521ba..1172af4e86 100644 --- a/framework/go.sum +++ b/framework/go.sum @@ -193,8 +193,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/framework/logstore/matviews.go b/framework/logstore/matviews.go index bf51748693..02126a4e38 100644 --- a/framework/logstore/matviews.go +++ b/framework/logstore/matviews.go @@ -5,6 +5,7 @@ import ( "fmt" "sort" "strings" + "sync/atomic" "time" "github.com/maximhq/bifrost/core/schemas" @@ -30,6 +31,10 @@ SELECT selected_key_id, COALESCE(virtual_key_id, '') AS virtual_key_id, COALESCE(routing_rule_id, '') AS routing_rule_id, + COALESCE(user_id, '') AS user_id, + COALESCE(team_id, '') AS team_id, + COALESCE(customer_id, '') AS customer_id, + COALESCE(business_unit_id, '') AS business_unit_id, COUNT(*) AS count, SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END) AS success_count, SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS error_count, @@ -44,13 +49,13 @@ SELECT COALESCE(SUM(cost), 0) AS total_cost FROM logs WHERE status IN ('success', 'error') -GROUP BY 1, 2, 3, 4, 5, 6, 7, 8 +GROUP BY 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ` // mvLogsHourlyUniqueIdx is required for REFRESH MATERIALIZED VIEW CONCURRENTLY. const mvLogsHourlyUniqueIdx = ` CREATE UNIQUE INDEX IF NOT EXISTS mv_logs_hourly_uniq -ON mv_logs_hourly (hour, provider, model, status, object_type, selected_key_id, virtual_key_id, routing_rule_id) +ON mv_logs_hourly (hour, provider, model, status, object_type, selected_key_id, virtual_key_id, routing_rule_id, user_id, team_id, customer_id, business_unit_id) ` // mvLogsFilterdataDDL creates a materialized view of distinct filter values @@ -67,7 +72,14 @@ SELECT DISTINCT COALESCE(virtual_key_name, '') AS virtual_key_name, COALESCE(routing_rule_id, '') AS routing_rule_id, COALESCE(routing_rule_name, '') AS routing_rule_name, - COALESCE(routing_engines_used, '') AS routing_engines_used + COALESCE(routing_engines_used, '') AS routing_engines_used, + COALESCE(user_id, '') AS user_id, + COALESCE(team_id, '') AS team_id, + COALESCE(team_name, '') AS team_name, + COALESCE(customer_id, '') AS customer_id, + COALESCE(customer_name, '') AS customer_name, + COALESCE(business_unit_id, '') AS business_unit_id, + COALESCE(business_unit_name, '') AS business_unit_name FROM logs WHERE timestamp >= NOW() - INTERVAL '60 days' AND model IS NOT NULL AND model != '' @@ -77,7 +89,7 @@ WHERE timestamp >= NOW() - INTERVAL '60 days' // Includes both ID and name columns so renamed keys don't cause duplicate violations. const mvLogsFilterdataUniqueIdx = ` CREATE UNIQUE INDEX IF NOT EXISTS mv_logs_filterdata_uniq -ON mv_logs_filterdata (model, provider, selected_key_id, selected_key_name, virtual_key_id, virtual_key_name, routing_rule_id, routing_rule_name, routing_engines_used) +ON mv_logs_filterdata (model, provider, selected_key_id, selected_key_name, virtual_key_id, virtual_key_name, routing_rule_id, routing_rule_name, routing_engines_used, user_id, team_id, team_name, customer_id, customer_name, business_unit_id, business_unit_name) ` // --------------------------------------------------------------------------- @@ -138,8 +150,10 @@ func refreshMatViews(ctx context.Context, db *gorm.DB) error { } // startMatViewRefresher launches a background goroutine that periodically -// refreshes materialized views. Returns a stop function for graceful shutdown. -func startMatViewRefresher(ctx context.Context, db *gorm.DB, interval time.Duration, logger schemas.Logger) func() { +// refreshes materialized views. If readyFlag is provided and not yet true, +// it will be set to true on the first successful refresh (recovery path when +// the initial refresh failed). Returns a stop function for graceful shutdown. +func startMatViewRefresher(ctx context.Context, db *gorm.DB, interval time.Duration, logger schemas.Logger, readyFlag *atomic.Bool) func() { stopCh := make(chan struct{}) go func() { ticker := time.NewTicker(interval) @@ -149,6 +163,9 @@ func startMatViewRefresher(ctx context.Context, db *gorm.DB, interval time.Durat case <-ticker.C: if err := refreshMatViews(ctx, db); err != nil { logger.Warn(fmt.Sprintf("logstore: matview refresh failed: %s", err)) + } else if readyFlag != nil && !readyFlag.Load() { + logger.Info("logstore: materialized views are ready (recovered)") + readyFlag.Store(true) } case <-ctx.Done(): return @@ -160,10 +177,10 @@ func startMatViewRefresher(ctx context.Context, db *gorm.DB, interval time.Durat return func() { close(stopCh) } } -// canUseMatView returns true if the given filters can be served from +// canUseMatViewFilters returns true if the given filters can be served from // mv_logs_hourly. Per-row filters (content search, metadata, numeric ranges) // require the raw logs table. -func canUseMatView(f SearchFilters) bool { +func canUseMatViewFilters(f SearchFilters) bool { return f.ContentSearch == "" && len(f.MetadataFilters) == 0 && len(f.RoutingEngineUsed) == 0 && @@ -173,6 +190,15 @@ func canUseMatView(f SearchFilters) bool { !f.MissingCostOnly } +// canUseMatView checks both that materialized views are ready (created and +// populated) and that the given filters are eligible for the matview path. +// This prevents queries from hitting non-existent views during the startup +// window between migration (which drops old views) and ensureMatViews (which +// recreates them asynchronously). +func (s *RDBLogStore) canUseMatView(f SearchFilters) bool { + return s.matViewsReady.Load() && canUseMatViewFilters(f) +} + // --------------------------------------------------------------------------- // Mat-view filter helpers // --------------------------------------------------------------------------- @@ -206,6 +232,18 @@ func applyMatViewFilters(q *gorm.DB, f SearchFilters) *gorm.DB { if len(f.RoutingRuleIDs) > 0 { q = q.Where("routing_rule_id IN ?", f.RoutingRuleIDs) } + if len(f.TeamIDs) > 0 { + q = q.Where("team_id IN ?", f.TeamIDs) + } + if len(f.CustomerIDs) > 0 { + q = q.Where("customer_id IN ?", f.CustomerIDs) + } + if len(f.UserIDs) > 0 { + q = q.Where("user_id IN ?", f.UserIDs) + } + if len(f.BusinessUnitIDs) > 0 { + q = q.Where("business_unit_id IN ?", f.BusinessUnitIDs) + } return q } @@ -700,6 +738,200 @@ func (s *RDBLogStore) getProviderLatencyHistogramFromMatView(ctx context.Context return &ProviderLatencyHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Providers: providers}, nil } +// --------------------------------------------------------------------------- +// Generic dimension histogram queries (cost, tokens, latency grouped by any dimension) +// --------------------------------------------------------------------------- + +// getDimensionCostHistogramFromMatView returns time-bucketed cost data grouped by +// the specified dimension column from mv_logs_hourly. +func (s *RDBLogStore) getDimensionCostHistogramFromMatView(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error) { + dimCol := string(dimension) + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + Cost float64 `gorm:"column:cost"` + } + q := s.db.WithContext(ctx).Table("mv_logs_hourly") + q = applyMatViewFilters(q, filters) + if err := q.Select(fmt.Sprintf(` + CAST(FLOOR(EXTRACT(EPOCH FROM hour) / %d) * %d AS BIGINT) AS bucket_timestamp, + %s AS dim_value, + SUM(total_cost) AS cost + `, bucketSizeSeconds, bucketSizeSeconds, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type bucketAgg struct { + totalCost float64 + byDimension map[string]float64 + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]float64)} + grouped[r.BucketTimestamp] = a + } + a.totalCost += r.Cost + a.byDimension[r.DimValue] += r.Cost + dimSet[r.DimValue] = struct{}{} + } + + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + buckets := make([]DimensionCostHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionCostHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]float64)} + if a, ok := grouped[ts]; ok { + b.TotalCost = a.totalCost + b.ByDimension = a.byDimension + } + buckets = append(buckets, b) + } + + dimValues := sortedStringKeys(dimSet) + return &DimensionCostHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + +// getDimensionTokenHistogramFromMatView returns time-bucketed token usage grouped by +// the specified dimension column from mv_logs_hourly. +func (s *RDBLogStore) getDimensionTokenHistogramFromMatView(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error) { + dimCol := string(dimension) + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + PromptTokens int64 `gorm:"column:prompt_tokens"` + CompletionTokens int64 `gorm:"column:completion_tokens"` + TotalTokens int64 `gorm:"column:total_tkns"` + } + q := s.db.WithContext(ctx).Table("mv_logs_hourly") + q = applyMatViewFilters(q, filters) + if err := q.Select(fmt.Sprintf(` + CAST(FLOOR(EXTRACT(EPOCH FROM hour) / %d) * %d AS BIGINT) AS bucket_timestamp, + %s AS dim_value, + SUM(total_prompt_tokens) AS prompt_tokens, + SUM(total_completion_tokens) AS completion_tokens, + SUM(total_tokens) AS total_tkns + `, bucketSizeSeconds, bucketSizeSeconds, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type dimAgg struct { + prompt, completion, total int64 + } + type bucketAgg struct { + byDimension map[string]*dimAgg + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]*dimAgg)} + grouped[r.BucketTimestamp] = a + } + da, ok := a.byDimension[r.DimValue] + if !ok { + da = &dimAgg{} + a.byDimension[r.DimValue] = da + } + da.prompt += r.PromptTokens + da.completion += r.CompletionTokens + da.total += r.TotalTokens + dimSet[r.DimValue] = struct{}{} + } + + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + buckets := make([]DimensionTokenHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionTokenHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]DimensionTokenStats)} + if a, ok := grouped[ts]; ok { + for dim, da := range a.byDimension { + b.ByDimension[dim] = DimensionTokenStats{ + PromptTokens: da.prompt, + CompletionTokens: da.completion, + TotalTokens: da.total, + } + } + } + buckets = append(buckets, b) + } + + dimValues := sortedStringKeys(dimSet) + return &DimensionTokenHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + +// getDimensionLatencyHistogramFromMatView returns time-bucketed latency percentiles +// grouped by the specified dimension column from mv_logs_hourly. +func (s *RDBLogStore) getDimensionLatencyHistogramFromMatView(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error) { + dimCol := string(dimension) + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + AvgLatency float64 `gorm:"column:avg_lat"` + P90Latency float64 `gorm:"column:p90_lat"` + P95Latency float64 `gorm:"column:p95_lat"` + P99Latency float64 `gorm:"column:p99_lat"` + TotalRequests int64 `gorm:"column:total_requests"` + } + q := s.db.WithContext(ctx).Table("mv_logs_hourly") + q = applyMatViewFilters(q, filters) + if err := q.Select(fmt.Sprintf(` + CAST(FLOOR(EXTRACT(EPOCH FROM hour) / %d) * %d AS BIGINT) AS bucket_timestamp, + %s AS dim_value, + CASE WHEN SUM(count) > 0 THEN SUM(avg_latency * count) / SUM(count) ELSE 0 END AS avg_lat, + CASE WHEN SUM(count) > 0 THEN SUM(p90_latency * count) / SUM(count) ELSE 0 END AS p90_lat, + CASE WHEN SUM(count) > 0 THEN SUM(p95_latency * count) / SUM(count) ELSE 0 END AS p95_lat, + CASE WHEN SUM(count) > 0 THEN SUM(p99_latency * count) / SUM(count) ELSE 0 END AS p99_lat, + SUM(count) AS total_requests + `, bucketSizeSeconds, bucketSizeSeconds, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type bucketAgg struct { + byDimension map[string]DimensionLatencyStats + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]DimensionLatencyStats)} + grouped[r.BucketTimestamp] = a + } + a.byDimension[r.DimValue] = DimensionLatencyStats{ + AvgLatency: r.AvgLatency, + P90Latency: r.P90Latency, + P95Latency: r.P95Latency, + P99Latency: r.P99Latency, + TotalRequests: r.TotalRequests, + } + dimSet[r.DimValue] = struct{}{} + } + + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + buckets := make([]DimensionLatencyHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionLatencyHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]DimensionLatencyStats)} + if a, ok := grouped[ts]; ok { + b.ByDimension = a.byDimension + } + buckets = append(buckets, b) + } + + dimValues := sortedStringKeys(dimSet) + return &DimensionLatencyHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + // getModelRankingsFromMatView returns models ranked by usage with trend // comparison to the previous period of equal duration from mv_logs_hourly. func (s *RDBLogStore) getModelRankingsFromMatView(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error) { @@ -795,6 +1027,85 @@ func (s *RDBLogStore) getModelRankingsFromMatView(ctx context.Context, filters S return &ModelRankingResult{Rankings: rankings}, nil } +// getUserRankingsFromMatView returns users ranked by usage with trend +// comparison to the previous period of equal duration from mv_logs_hourly. +func (s *RDBLogStore) getUserRankingsFromMatView(ctx context.Context, filters SearchFilters) (*UserRankingResult, error) { + var results []struct { + UserID string `gorm:"column:user_id"` + Total int64 `gorm:"column:total"` + TotalTokens int64 `gorm:"column:total_tkns"` + TotalCost float64 `gorm:"column:total_cost"` + } + q := s.db.WithContext(ctx).Table("mv_logs_hourly") + q = applyMatViewFilters(q, filters) + q = q.Where("user_id != ''") + if err := q.Select(` + user_id, + SUM(count) AS total, + SUM(total_tokens) AS total_tkns, + SUM(total_cost) AS total_cost + `).Group("user_id"). + Order("total DESC"). + Find(&results).Error; err != nil { + return nil, err + } + + // Previous period for trend (same duration, ending just before current start) + type prevRow struct { + UserID string `gorm:"column:user_id"` + Total int64 `gorm:"column:total"` + TotalTokens int64 `gorm:"column:total_tkns"` + TotalCost float64 `gorm:"column:total_cost"` + } + var prevResults []prevRow + if filters.StartTime != nil && filters.EndTime != nil { + duration := filters.EndTime.Sub(*filters.StartTime) + prevStart := filters.StartTime.Add(-duration) + prevEnd := filters.StartTime.Add(-time.Nanosecond) + prevFilters := filters + prevFilters.StartTime = &prevStart + prevFilters.EndTime = &prevEnd + pq := s.db.WithContext(ctx).Table("mv_logs_hourly") + pq = applyMatViewFilters(pq, prevFilters) + pq = pq.Where("user_id != ''") + if err := pq.Select(` + user_id, + SUM(count) AS total, + SUM(total_tokens) AS total_tkns, + SUM(total_cost) AS total_cost + `).Group("user_id").Find(&prevResults).Error; err != nil { + return nil, fmt.Errorf("failed to get previous period user rankings: %w", err) + } + } + + prevMap := make(map[string]int, len(prevResults)) + for i, r := range prevResults { + prevMap[r.UserID] = i + } + + rankings := make([]UserRankingWithTrend, 0, len(results)) + for _, r := range results { + entry := UserRankingEntry{ + UserID: r.UserID, + TotalRequests: r.Total, + TotalTokens: r.TotalTokens, + TotalCost: r.TotalCost, + } + urt := UserRankingWithTrend{UserRankingEntry: entry} + if idx, ok := prevMap[r.UserID]; ok { + prev := prevResults[idx] + urt.Trend = UserRankingTrend{ + HasPreviousPeriod: true, + RequestsTrend: trendPct(float64(r.Total), float64(prev.Total)), + TokensTrend: trendPct(float64(r.TotalTokens), float64(prev.TotalTokens)), + CostTrend: trendPct(r.TotalCost, prev.TotalCost), + } + } + rankings = append(rankings, urt) + } + return &UserRankingResult{Rankings: rankings}, nil +} + // --------------------------------------------------------------------------- // Filterdata from mat view // --------------------------------------------------------------------------- diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index 8fc9ff7b72..5e97c749f0 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -203,6 +203,24 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddLogsAndDashboardPerformanceIndexes(ctx, db); err != nil { return err } + if err := migrationAddImageEditInputColumn(ctx, db); err != nil { + return err + } + if err := migrationAddImageVariationInputColumn(ctx, db); err != nil { + return err + } + if err := migrationAddPluginLogsColumn(ctx, db); err != nil { + return err + } + if err := migrationAddAliasColumn(ctx, db); err != nil { + return err + } + if err := migrationAddGovernanceContextColumns(ctx, db); err != nil { + return err + } + if err := migrationRecreateMatViewsWithGovernanceColumns(ctx, db); err != nil { + return err + } if err := migrationAddRequestIDColumnToMCPToolLogs(ctx, db); err != nil { return err } @@ -251,7 +269,7 @@ func migrationUpdateObjectColumnValues(ctx context.Context, db *gorm.DB) error { tx = tx.WithContext(ctx) updateSQL := ` - UPDATE logs + UPDATE logs SET object_type = CASE object_type WHEN 'chat.completion' THEN 'chat_completion' WHEN 'text.completion' THEN 'text_completion' @@ -268,7 +286,7 @@ func migrationUpdateObjectColumnValues(ctx context.Context, db *gorm.DB) error { WHERE object_type IN ( 'chat.completion', 'text.completion', 'list', 'audio.speech', 'audio.transcription', 'chat.completion.chunk', - 'audio.speech.chunk', 'audio.transcription.chunk', + 'audio.speech.chunk', 'audio.transcription.chunk', 'response', 'response.completion.chunk' )` @@ -284,7 +302,7 @@ func migrationUpdateObjectColumnValues(ctx context.Context, db *gorm.DB) error { // Use a single CASE statement for efficient bulk rollback rollbackSQL := ` - UPDATE logs + UPDATE logs SET object_type = CASE object_type WHEN 'chat_completion' THEN 'chat.completion' WHEN 'text_completion' THEN 'text.completion' @@ -782,17 +800,17 @@ func migrationUpdateTimestampFormat(ctx context.Context, db *gorm.DB) error { updateSQL := ` UPDATE logs - SET "timestamp" = strftime('%Y-%m-%dT%H:%M:%S', "timestamp", 'utc') || '.' || + SET "timestamp" = strftime('%Y-%m-%dT%H:%M:%S', "timestamp", 'utc') || '.' || CAST(CAST(strftime('%f', "timestamp") * 1000 AS INTEGER) % 1000 AS TEXT) || 'Z' - WHERE - "timestamp" NOT LIKE '%Z' + WHERE + "timestamp" NOT LIKE '%Z' AND "timestamp" NOT LIKE '%+00%'; UPDATE logs - SET created_at = strftime('%Y-%m-%dT%H:%M:%S', created_at, 'utc') || '.' || - CAST(CAST(strftime('%f', created_at) * 1000 AS INTEGER) % 1000 AS TEXT) || + SET created_at = strftime('%Y-%m-%dT%H:%M:%S', created_at, 'utc') || '.' || + CAST(CAST(strftime('%f', created_at) * 1000 AS INTEGER) % 1000 AS TEXT) || 'Z' - WHERE - created_at NOT LIKE '%Z' + WHERE + created_at NOT LIKE '%Z' AND created_at NOT LIKE '%+00%'; ` @@ -2075,6 +2093,36 @@ var performanceIndexes = []performanceIndexDef{ name: "idx_logs_ts_provider_status", sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_ts_provider_status ON logs(timestamp, provider, status)", }, + { + table: "logs", + name: "idx_logs_alias", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_alias ON logs(alias)", + }, + { + table: "logs", + name: "idx_logs_team_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_team_id ON logs(team_id)", + }, + { + table: "logs", + name: "idx_logs_customer_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_customer_id ON logs(customer_id)", + }, + { + table: "logs", + name: "idx_logs_user_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_user_id ON logs(user_id)", + }, + { + table: "logs", + name: "idx_logs_business_unit_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_business_unit_id ON logs(business_unit_id)", + }, + { + table: "logs", + name: "idx_logs_parent_request_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_parent_request_id ON logs(parent_request_id) WHERE parent_request_id IS NOT NULL", + }, } // ensurePerformanceIndexes checks whether each performance GIN index exists and is @@ -2119,3 +2167,218 @@ func ensurePerformanceIndexes(ctx context.Context, conn *sql.Conn) error { return nil } + +// migrationAddImageEditInputColumn adds the image_edit_input column to the logs table. +func migrationAddImageEditInputColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_image_edit_input_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&Log{}, "image_edit_input") { + if err := migrator.AddColumn(&Log{}, "image_edit_input"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&Log{}, "image_edit_input") { + if err := migrator.DropColumn(&Log{}, "image_edit_input"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding image edit input column: %s", err.Error()) + } + return nil +} + +// migrationAddImageVariationInputColumn adds the image_variation_input column to the logs table. +func migrationAddImageVariationInputColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_image_variation_input_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&Log{}, "image_variation_input") { + if err := migrator.AddColumn(&Log{}, "image_variation_input"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&Log{}, "image_variation_input") { + if err := migrator.DropColumn(&Log{}, "image_variation_input"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding image variation input column: %s", err.Error()) + } + return nil +} + +// migrationAddPluginLogsColumn adds the plugin_logs column to the logs table. +func migrationAddPluginLogsColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_plugin_logs_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if !migrator.HasColumn(&Log{}, "plugin_logs") { + if err := migrator.AddColumn(&Log{}, "plugin_logs"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&Log{}, "plugin_logs") { + if err := migrator.DropColumn(&Log{}, "plugin_logs"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding plugin logs column: %s", err.Error()) + } + return nil +} + +// migrationAddAliasColumn adds the alias column to the logs table. +// The alias field stores the original model name the caller used when routing resolved it to a different model via alias mapping. +// Index creation is deferred to ensurePerformanceIndexes (called post-startup in a background goroutine) +// because CREATE INDEX CONCURRENTLY cannot run inside a transaction and a regular CREATE INDEX +// takes a SHARE lock that blocks writes on large tables during rolling deploys. +func migrationAddAliasColumn(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_alias_column", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if !mig.HasColumn(&Log{}, "alias") { + if err := mig.AddColumn(&Log{}, "alias"); err != nil { + return err + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + if mig.HasColumn(&Log{}, "alias") { + if err := mig.DropColumn(&Log{}, "alias"); err != nil { + return err + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding alias column: %s", err.Error()) + } + return nil +} + +// migrationAddGovernanceContextColumns adds user_id, team_id, team_name, customer_id, customer_name, +// business_unit_id, business_unit_name columns to the logs table. +func migrationAddGovernanceContextColumns(ctx context.Context, db *gorm.DB) error { + opts := *migrator.DefaultOptions + opts.UseTransaction = true + + columns := []string{"user_id", "team_id", "team_name", "customer_id", "customer_name", "business_unit_id", "business_unit_name"} + + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_add_governance_context_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + for _, col := range columns { + if !mig.HasColumn(&Log{}, col) { + if err := mig.AddColumn(&Log{}, col); err != nil { + return err + } + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + mig := tx.Migrator() + for _, col := range columns { + if mig.HasColumn(&Log{}, col) { + if err := mig.DropColumn(&Log{}, col); err != nil { + return err + } + } + } + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while adding governance context columns: %s", err.Error()) + } + return nil +} + +// migrationRecreateMatViewsWithGovernanceColumns drops and recreates materialized views +// so they include the new governance context columns (user_id, team_id, customer_id, business_unit_id). +// The views are recreated by ensureMatViews on startup, so we just need to drop the old ones. +func migrationRecreateMatViewsWithGovernanceColumns(ctx context.Context, db *gorm.DB) error { + // Materialized views are PostgreSQL-only; skip on other dialects + if db.Dialector.Name() != "postgres" { + return nil + } + opts := *migrator.DefaultOptions + opts.UseTransaction = true + m := migrator.New(db, &opts, []*migrator.Migration{{ + ID: "logs_recreate_matviews_with_governance_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + for _, view := range []string{"mv_logs_hourly", "mv_logs_filterdata"} { + if err := tx.Exec("DROP MATERIALIZED VIEW IF EXISTS " + view + " CASCADE").Error; err != nil { + return fmt.Errorf("failed to drop %s: %w", view, err) + } + } + return nil + }, + Rollback: func(tx *gorm.DB) error { + // No rollback needed β€” ensureMatViews will recreate on next startup + return nil + }, + }}) + err := m.Migrate() + if err != nil { + return fmt.Errorf("error while recreating matviews with governance columns: %s", err.Error()) + } + return nil +} diff --git a/framework/logstore/migrations_test.go b/framework/logstore/migrations_test.go index ed53ffbf76..0ee91b4b4d 100644 --- a/framework/logstore/migrations_test.go +++ b/framework/logstore/migrations_test.go @@ -45,11 +45,12 @@ func trySetupPostgresDB(t *testing.T) *gorm.DB { func setupLogsTableForGINIndexTest(t *testing.T, db *gorm.DB) { t.Helper() - // Drop existing tables and migration tracking in the correct order - // Note: The migrator uses "migrations" table by default, not "gomigrate" + // Drop existing tables and migration tracking in the correct order. + // Preserve the shared migrations table β€” only clear its rows. db.Exec("DROP INDEX IF EXISTS idx_logs_metadata_gin") db.Exec("DROP TABLE IF EXISTS logs") - db.Exec("DROP TABLE IF EXISTS migrations") + db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)") + db.Exec("DELETE FROM migrations") // Create a minimal logs table with only the columns needed for the test err := db.Exec(` @@ -72,7 +73,7 @@ func setupLogsTableForGINIndexTest(t *testing.T, db *gorm.DB) { t.Cleanup(func() { db.Exec("DROP INDEX IF EXISTS idx_logs_metadata_gin") db.Exec("DROP TABLE IF EXISTS logs") - db.Exec("DROP TABLE IF EXISTS migrations") + db.Exec("DELETE FROM migrations") }) } diff --git a/framework/logstore/postgres.go b/framework/logstore/postgres.go index 187f190a7a..df78b1735d 100644 --- a/framework/logstore/postgres.go +++ b/framework/logstore/postgres.go @@ -144,8 +144,11 @@ func newPostgresLogStore(ctx context.Context, config *PostgresConfig, logger sch logger.Warn(fmt.Sprintf("logstore: initial matview refresh failed: %s", err)) } else { logger.Info("logstore: materialized views are ready") + // Signal that matviews are ready for query use. Until this point, + // canUseMatView() returns false so all queries use raw tables. + d.matViewsReady.Store(true) } - startMatViewRefresher(context.Background(), db, 30*time.Second, logger) + startMatViewRefresher(context.Background(), db, 30*time.Second, logger, &d.matViewsReady) }() return d, nil diff --git a/framework/logstore/rdb.go b/framework/logstore/rdb.go index e9c450f26c..cd82e67b72 100644 --- a/framework/logstore/rdb.go +++ b/framework/logstore/rdb.go @@ -10,6 +10,7 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "time" "github.com/bytedance/sonic" @@ -29,6 +30,7 @@ func isValidMetadataKey(key string) bool { } const bulkUpdateCostChunkSize = 500 +const sessionLogPageLimit = 50 const ( // defaultMaxQueryLimit is a safety cap for unbounded queries (FindAll, FindAllDistinct). @@ -45,8 +47,9 @@ const ( // RDBLogStore represents a log store that uses a SQLite database. type RDBLogStore struct { - db *gorm.DB - logger schemas.Logger + db *gorm.DB + logger schemas.Logger + matViewsReady atomic.Bool } // generateBucketTimestamps generates all bucket timestamps for a time range. @@ -79,12 +82,18 @@ func (s *RDBLogStore) applyFilters(baseQuery *gorm.DB, filters SearchFilters) *g if len(filters.Models) > 0 { baseQuery = baseQuery.Where("model IN ?", filters.Models) } + if len(filters.Aliases) > 0 { + baseQuery = baseQuery.Where("alias IN ?", filters.Aliases) + } if len(filters.Status) > 0 { baseQuery = baseQuery.Where("status IN ?", filters.Status) } if len(filters.Objects) > 0 { baseQuery = baseQuery.Where("object_type IN ?", filters.Objects) } + if filters.ParentRequestID != "" { + baseQuery = baseQuery.Where("parent_request_id = ?", filters.ParentRequestID) + } if len(filters.SelectedKeyIDs) > 0 { baseQuery = baseQuery.Where("selected_key_id IN ?", filters.SelectedKeyIDs) } @@ -94,6 +103,18 @@ func (s *RDBLogStore) applyFilters(baseQuery *gorm.DB, filters SearchFilters) *g if len(filters.RoutingRuleIDs) > 0 { baseQuery = baseQuery.Where("routing_rule_id IN ?", filters.RoutingRuleIDs) } + if len(filters.TeamIDs) > 0 { + baseQuery = baseQuery.Where("team_id IN ?", filters.TeamIDs) + } + if len(filters.CustomerIDs) > 0 { + baseQuery = baseQuery.Where("customer_id IN ?", filters.CustomerIDs) + } + if len(filters.UserIDs) > 0 { + baseQuery = baseQuery.Where("user_id IN ?", filters.UserIDs) + } + if len(filters.BusinessUnitIDs) > 0 { + baseQuery = baseQuery.Where("business_unit_id IN ?", filters.BusinessUnitIDs) + } if len(filters.RoutingEngineUsed) > 0 { // Query routing engines (comma-separated values) - find logs containing ANY of the specified engines dialect := s.db.Dialector.Name() @@ -394,7 +415,7 @@ func (s *RDBLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pag g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) { + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) { var err error totalCount, err = s.getCountFromMatView(gCtx, filters) return err @@ -441,16 +462,176 @@ func (s *RDBLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pag }, nil } +// GetSessionLogs returns paginated logs for a single parent_request_id session. +func (s *RDBLogStore) GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + + limit := pagination.Limit + if limit <= 0 || limit > sessionLogPageLimit { + limit = sessionLogPageLimit + } + pagination.Limit = limit + if pagination.Offset < 0 { + pagination.Offset = 0 + } + + pagination.SortBy = "timestamp" + orderDir := "ASC" + if pagination.Order == "desc" { + orderDir = "DESC" + } + orderClause := "timestamp " + orderDir + ", id " + orderDir + + baseQuery := s.db.WithContext(ctx).Model(&Log{}).Where("parent_request_id = ?", sessionID) + + var ( + totalCount int64 + logs []Log + ) + + g, gCtx := errgroup.WithContext(ctx) + + g.Go(func() error { + return s.db.WithContext(gCtx).Model(&Log{}).Where("parent_request_id = ?", sessionID).Count(&totalCount).Error + }) + + g.Go(func() error { + dataQuery := baseQuery.Session(&gorm.Session{}). + WithContext(gCtx). + Order(orderClause). + Select(s.listSelectColumns()). + Limit(limit) + if pagination.Offset > 0 { + dataQuery = dataQuery.Offset(pagination.Offset) + } + err := dataQuery.Find(&logs).Error + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + }) + + if err := g.Wait(); err != nil { + return nil, err + } + + pagination.TotalCount = totalCount + returnedCount := len(logs) + return &SessionDetailResult{ + SessionID: sessionID, + Logs: logs, + Pagination: pagination, + Count: totalCount, + ReturnedCount: returnedCount, + HasMore: int64(pagination.Offset+returnedCount) < totalCount, + }, nil +} + +// GetSessionSummary returns aggregate totals for a single parent_request_id session. +func (s *RDBLogStore) GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + + var ( + count int64 + totalCost float64 + totalTokens int64 + startedAt string + latestAt string + startedRaw any + latestRaw any + ) + + // Single aggregate select keeps Count/SUM/MIN/MAX consistent against the same row snapshot + // and halves the round trips compared to running Count and the aggregate row in parallel. + row := s.db.WithContext(ctx). + Model(&Log{}). + Where("parent_request_id = ?", sessionID). + Select("COUNT(*) AS count, COALESCE(SUM(cost), 0) AS total_cost, COALESCE(SUM(total_tokens), 0) AS total_tokens, MIN(timestamp) AS started_at, MAX(timestamp) AS latest_at"). + Row() + + if err := row.Scan(&count, &totalCost, &totalTokens, &startedRaw, &latestRaw); err != nil { + return nil, err + } + + startedAt = normalizeAggregateTimestamp(startedRaw) + latestAt = normalizeAggregateTimestamp(latestRaw) + + durationMs := int64(0) + if startedAt != "" && latestAt != "" { + if startedTime, err := time.Parse(time.RFC3339Nano, startedAt); err == nil { + if latestTime, err := time.Parse(time.RFC3339Nano, latestAt); err == nil { + durationMs = latestTime.Sub(startedTime).Milliseconds() + if durationMs < 0 { + durationMs = 0 + } + } + } + } + + return &SessionSummaryResult{ + SessionID: sessionID, + Count: count, + TotalCost: totalCost, + TotalTokens: totalTokens, + StartedAt: startedAt, + LatestAt: latestAt, + DurationMs: durationMs, + }, nil +} + +func normalizeAggregateTimestamp(value any) string { + switch v := value.(type) { + case nil: + return "" + case time.Time: + return v.UTC().Format(time.RFC3339Nano) + case []byte: + return normalizeAggregateTimestamp(string(v)) + case string: + raw := strings.TrimSpace(v) + if raw == "" { + return "" + } + layouts := []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999Z07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05.999999999", + "2006-01-02T15:04:05", + } + for _, layout := range layouts { + if parsed, err := time.Parse(layout, raw); err == nil { + return parsed.UTC().Format(time.RFC3339Nano) + } + } + return raw + default: + return fmt.Sprint(v) + } +} + // listSelectColumns returns a SELECT clause for list queries that omits large // output/detail TEXT columns and uses SQL JSON functions to extract only the // last element from input_history and responses_input_history arrays. +// +// Realtime turn rows are kept intact because the logs table renders them as a +// combined Tool/User/Assistant summary and needs the full turn context. func (s *RDBLogStore) listSelectColumns() string { baseCols := strings.Join([]string{ - "id", "parent_request_id", "timestamp", "object_type", "provider", "model", + "id", "parent_request_id", "timestamp", "object_type", "provider", "model", "alias", "number_of_retries", "fallback_index", "selected_key_id", "selected_key_name", "virtual_key_id", "virtual_key_name", "routing_engines_used", "routing_rule_id", "routing_rule_name", + "user_id", "team_id", "team_name", "customer_id", "customer_name", + "business_unit_id", "business_unit_name", "speech_input", "transcription_input", "image_generation_input", "video_generation_input", "latency", "token_usage", "cost", "status", "error_details", "stream", "content_summary", "metadata", @@ -459,30 +640,40 @@ func (s *RDBLogStore) listSelectColumns() string { "created_at", }, ", ") - var inputHistoryExpr, responsesInputExpr string + var inputHistoryExpr, responsesInputExpr, outputMessageExpr string switch s.db.Dialector.Name() { case "postgres": - inputHistoryExpr = `CASE WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' + inputHistoryExpr = `CASE + WHEN object_type = 'realtime.turn' THEN input_history + WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' THEN jsonb_build_array(input_history::jsonb->-1)::text ELSE input_history END AS input_history` - responsesInputExpr = `CASE WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' + responsesInputExpr = `CASE + WHEN object_type = 'realtime.turn' THEN responses_input_history + WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' THEN jsonb_build_array(responses_input_history::jsonb->-1)::text ELSE responses_input_history END AS responses_input_history` + outputMessageExpr = `CASE WHEN object_type = 'realtime.turn' THEN output_message ELSE NULL END AS output_message` default: // sqlite - inputHistoryExpr = `CASE WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' + inputHistoryExpr = `CASE + WHEN object_type = 'realtime.turn' THEN input_history + WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' THEN json_array(json_extract(input_history, '$[' || (json_array_length(input_history) - 1) || ']')) ELSE input_history END AS input_history` - responsesInputExpr = `CASE WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' + responsesInputExpr = `CASE + WHEN object_type = 'realtime.turn' THEN responses_input_history + WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' THEN json_array(json_extract(responses_input_history, '$[' || (json_array_length(responses_input_history) - 1) || ']')) ELSE responses_input_history END AS responses_input_history` + outputMessageExpr = `CASE WHEN object_type = 'realtime.turn' THEN output_message ELSE NULL END AS output_message` } - return baseCols + ", " + inputHistoryExpr + ", " + responsesInputExpr + return baseCols + ", " + inputHistoryExpr + ", " + responsesInputExpr + ", " + outputMessageExpr } // GetStats calculates statistics for logs matching the given filters. func (s *RDBLogStore) GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) { + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) { return s.getStatsFromMatView(ctx, filters) } baseQuery := s.db.WithContext(ctx).Model(&Log{}) @@ -542,12 +733,12 @@ func (s *RDBLogStore) GetStats(ctx context.Context, filters SearchFilters) (*Sea // GetHistogram returns time-bucketed request counts for the given filters. func (s *RDBLogStore) GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 // Default to 1 hour } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } // Determine database type for SQL syntax dialect := s.db.Dialector.Name() @@ -668,12 +859,12 @@ func (s *RDBLogStore) GetHistogram(ctx context.Context, filters SearchFilters, b // GetTokenHistogram returns time-bucketed token usage for the given filters. func (s *RDBLogStore) GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getTokenHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 // Default to 1 hour } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getTokenHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -794,12 +985,12 @@ func (s *RDBLogStore) GetTokenHistogram(ctx context.Context, filters SearchFilte // GetCostHistogram returns time-bucketed cost data with model breakdown for the given filters. func (s *RDBLogStore) GetCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*CostHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getCostHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 // Default to 1 hour } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getCostHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -916,12 +1107,12 @@ func (s *RDBLogStore) GetCostHistogram(ctx context.Context, filters SearchFilter // GetModelHistogram returns time-bucketed model usage with success/error breakdown for the given filters. func (s *RDBLogStore) GetModelHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ModelHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getModelHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 // Default to 1 hour } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getModelHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -1071,12 +1262,12 @@ func computePercentile(sorted []float64, p float64) float64 { // PostgreSQL uses database-level percentile_cont aggregation (returns 1 row per bucket). // MySQL and SQLite fall back to Go-based percentile computation (loads individual latency values). func (s *RDBLogStore) GetLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*LatencyHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getLatencyHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getLatencyHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -1285,7 +1476,7 @@ func (s *RDBLogStore) buildLatencyHistogramResult(computedBuckets map[int64]Late // GetModelRankings returns models ranked by usage with trend comparison to the previous period. func (s *RDBLogStore) GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) { + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) { return s.getModelRankingsFromMatView(ctx, filters) } selectClause := ` @@ -1423,6 +1614,115 @@ func (s *RDBLogStore) GetModelRankings(ctx context.Context, filters SearchFilter return &ModelRankingResult{Rankings: rankings}, nil } +// GetUserRankings returns users ranked by usage with trend comparison to the previous period. +func (s *RDBLogStore) GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error) { + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) { + return s.getUserRankingsFromMatView(ctx, filters) + } + selectClause := ` + user_id, + COUNT(*) as total_requests, + SUM(total_tokens) as total_tokens, + COALESCE(SUM(cost), 0) as total_cost + ` + + // Query current period + currentQuery := s.db.WithContext(ctx).Model(&Log{}) + currentQuery = s.applyFilters(currentQuery, filters) + currentQuery = currentQuery.Where("status IN ?", []string{"success", "error"}) + currentQuery = currentQuery.Where("user_id IS NOT NULL AND user_id != ''") + + var currentResults []struct { + UserID string `gorm:"column:user_id"` + TotalRequests int64 `gorm:"column:total_requests"` + TotalTokens sql.NullInt64 `gorm:"column:total_tokens"` + TotalCost sql.NullFloat64 `gorm:"column:total_cost"` + } + + if err := currentQuery. + Select(selectClause). + Group("user_id"). + Order("total_requests DESC"). + Limit(defaultMaxRankingsLimit). + Find(¤tResults).Error; err != nil { + return nil, fmt.Errorf("failed to get user rankings: %w", err) + } + + // Query previous period for trend comparison + prevMap := make(map[string]UserRankingEntry) + if filters.StartTime != nil && filters.EndTime != nil { + duration := filters.EndTime.Sub(*filters.StartTime) + prevStart := filters.StartTime.Add(-duration) + prevEnd := filters.StartTime.Add(-time.Nanosecond) + + prevFilters := filters + prevFilters.StartTime = &prevStart + prevFilters.EndTime = &prevEnd + + prevQuery := s.db.WithContext(ctx).Model(&Log{}) + prevQuery = s.applyFilters(prevQuery, prevFilters) + prevQuery = prevQuery.Where("status IN ?", []string{"success", "error"}) + prevQuery = prevQuery.Where("user_id IS NOT NULL AND user_id != ''") + + if len(currentResults) > 0 { + userIDs := make([]string, len(currentResults)) + for i, r := range currentResults { + userIDs[i] = r.UserID + } + prevQuery = prevQuery.Where("user_id IN ?", userIDs) + } + + var prevResults []struct { + UserID string `gorm:"column:user_id"` + TotalRequests int64 `gorm:"column:total_requests"` + TotalTokens sql.NullInt64 `gorm:"column:total_tokens"` + TotalCost sql.NullFloat64 `gorm:"column:total_cost"` + } + + if err := prevQuery. + Select(selectClause). + Group("user_id"). + Find(&prevResults).Error; err != nil { + return nil, fmt.Errorf("failed to get previous period user rankings: %w", err) + } + + for _, r := range prevResults { + prevMap[r.UserID] = UserRankingEntry{ + UserID: r.UserID, + TotalRequests: r.TotalRequests, + TotalTokens: r.TotalTokens.Int64, + TotalCost: r.TotalCost.Float64, + } + } + } + + // Build results with trends + rankings := make([]UserRankingWithTrend, len(currentResults)) + for i, r := range currentResults { + entry := UserRankingEntry{ + UserID: r.UserID, + TotalRequests: r.TotalRequests, + TotalTokens: r.TotalTokens.Int64, + TotalCost: r.TotalCost.Float64, + } + + var trend UserRankingTrend + if prev, ok := prevMap[r.UserID]; ok && prev.TotalRequests > 0 { + trend.HasPreviousPeriod = true + trend.RequestsTrend = pctChange(float64(prev.TotalRequests), float64(r.TotalRequests)) + trend.TokensTrend = pctChange(float64(prev.TotalTokens), float64(r.TotalTokens.Int64)) + trend.CostTrend = pctChange(prev.TotalCost, r.TotalCost.Float64) + } + + rankings[i] = UserRankingWithTrend{ + UserRankingEntry: entry, + Trend: trend, + } + } + + return &UserRankingResult{Rankings: rankings}, nil +} + // pctChange computes the percentage change from old to new. func pctChange(old, new float64) float64 { if old == 0 { @@ -1433,12 +1733,12 @@ func pctChange(old, new float64) float64 { // GetProviderCostHistogram returns time-bucketed cost data with provider breakdown for the given filters. func (s *RDBLogStore) GetProviderCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderCostHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getProviderCostHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getProviderCostHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -1544,12 +1844,12 @@ func (s *RDBLogStore) GetProviderCostHistogram(ctx context.Context, filters Sear // GetProviderTokenHistogram returns time-bucketed token usage with provider breakdown for the given filters. func (s *RDBLogStore) GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getProviderTokenHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getProviderTokenHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -1671,12 +1971,12 @@ func (s *RDBLogStore) GetProviderTokenHistogram(ctx context.Context, filters Sea // PostgreSQL uses database-level percentile_cont aggregation. // MySQL and SQLite fall back to Go-based percentile computation. func (s *RDBLogStore) GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error) { - if s.db.Dialector.Name() == "postgres" && canUseMatView(filters) && bucketSizeSeconds >= 3600 { - return s.getProviderLatencyHistogramFromMatView(ctx, filters, bucketSizeSeconds) - } if bucketSizeSeconds <= 0 { bucketSizeSeconds = 3600 } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getProviderLatencyHistogramFromMatView(ctx, filters, bucketSizeSeconds) + } dialect := s.db.Dialector.Name() @@ -1941,6 +2241,317 @@ func (s *RDBLogStore) buildProviderLatencyHistogramResult(computedBuckets map[in }, nil } +// --------------------------------------------------------------------------- +// Generic dimension histogram methods +// --------------------------------------------------------------------------- + +// GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension. +// Uses the mv_logs_hourly materialized view on PostgreSQL when eligible; falls back to raw queries otherwise. +func (s *RDBLogStore) GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error) { + if !ValidHistogramDimensions[dimension] { + return nil, fmt.Errorf("invalid histogram dimension: %s", dimension) + } + if bucketSizeSeconds <= 0 { + bucketSizeSeconds = 3600 + } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getDimensionCostHistogramFromMatView(ctx, filters, bucketSizeSeconds, dimension) + } + dimCol := string(dimension) + dialect := s.db.Dialector.Name() + baseQuery := s.db.WithContext(ctx).Model(&Log{}) + baseQuery = s.applyFilters(baseQuery, filters) + baseQuery = baseQuery.Where("status IN ?", []string{"success", "error"}) + baseQuery = baseQuery.Where("cost IS NOT NULL AND cost > 0") + + var bucketExpr string + switch dialect { + case "sqlite": + bucketExpr = fmt.Sprintf("CAST((CAST(strftime('%%s', timestamp) AS INTEGER) / %d) * %d AS INTEGER)", bucketSizeSeconds, bucketSizeSeconds) + default: + bucketExpr = fmt.Sprintf("CAST(FLOOR(EXTRACT(EPOCH FROM timestamp) / %d) * %d AS BIGINT)", bucketSizeSeconds, bucketSizeSeconds) + } + + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + Cost float64 `gorm:"column:cost"` + } + if err := baseQuery.Select(fmt.Sprintf(` + %s AS bucket_timestamp, + COALESCE(%s, '') AS dim_value, + SUM(cost) AS cost + `, bucketExpr, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type bucketAgg struct { + totalCost float64 + byDimension map[string]float64 + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]float64)} + grouped[r.BucketTimestamp] = a + } + a.totalCost += r.Cost + a.byDimension[r.DimValue] += r.Cost + dimSet[r.DimValue] = struct{}{} + } + + dimValues := sortedStringKeys(dimSet) + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + + // If no time range specified, build buckets directly from query results + if len(allTimestamps) == 0 { + keys := make([]int64, 0, len(grouped)) + for ts := range grouped { + keys = append(keys, ts) + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + buckets := make([]DimensionCostHistogramBucket, 0, len(keys)) + for _, ts := range keys { + a := grouped[ts] + buckets = append(buckets, DimensionCostHistogramBucket{ + Timestamp: time.Unix(ts, 0).UTC(), + TotalCost: a.totalCost, + ByDimension: a.byDimension, + }) + } + return &DimensionCostHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil + } + + buckets := make([]DimensionCostHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionCostHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]float64)} + if a, ok := grouped[ts]; ok { + b.TotalCost = a.totalCost + b.ByDimension = a.byDimension + } + buckets = append(buckets, b) + } + + return &DimensionCostHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + +// GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension. +// Uses the mv_logs_hourly materialized view on PostgreSQL when eligible; falls back to raw queries otherwise. +func (s *RDBLogStore) GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error) { + if !ValidHistogramDimensions[dimension] { + return nil, fmt.Errorf("invalid histogram dimension: %s", dimension) + } + if bucketSizeSeconds <= 0 { + bucketSizeSeconds = 3600 + } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getDimensionTokenHistogramFromMatView(ctx, filters, bucketSizeSeconds, dimension) + } + dimCol := string(dimension) + dialect := s.db.Dialector.Name() + baseQuery := s.db.WithContext(ctx).Model(&Log{}) + baseQuery = s.applyFilters(baseQuery, filters) + baseQuery = baseQuery.Where("status IN ?", []string{"success", "error"}) + + var bucketExpr string + switch dialect { + case "sqlite": + bucketExpr = fmt.Sprintf("CAST((CAST(strftime('%%s', timestamp) AS INTEGER) / %d) * %d AS INTEGER)", bucketSizeSeconds, bucketSizeSeconds) + default: + bucketExpr = fmt.Sprintf("CAST(FLOOR(EXTRACT(EPOCH FROM timestamp) / %d) * %d AS BIGINT)", bucketSizeSeconds, bucketSizeSeconds) + } + + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + PromptTokens int64 `gorm:"column:prompt_tokens"` + CompletionTokens int64 `gorm:"column:completion_tokens"` + TotalTokens int64 `gorm:"column:total_tkns"` + } + if err := baseQuery.Select(fmt.Sprintf(` + %s AS bucket_timestamp, + COALESCE(%s, '') AS dim_value, + COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, + COALESCE(SUM(completion_tokens), 0) AS completion_tokens, + COALESCE(SUM(total_tokens), 0) AS total_tkns + `, bucketExpr, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type dimAgg struct { + prompt, completion, total int64 + } + type bucketAgg struct { + byDimension map[string]*dimAgg + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]*dimAgg)} + grouped[r.BucketTimestamp] = a + } + da, ok := a.byDimension[r.DimValue] + if !ok { + da = &dimAgg{} + a.byDimension[r.DimValue] = da + } + da.prompt += r.PromptTokens + da.completion += r.CompletionTokens + da.total += r.TotalTokens + dimSet[r.DimValue] = struct{}{} + } + + dimValues := sortedStringKeys(dimSet) + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + + // If no time range specified, build buckets directly from query results + if len(allTimestamps) == 0 { + keys := make([]int64, 0, len(grouped)) + for ts := range grouped { + keys = append(keys, ts) + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + buckets := make([]DimensionTokenHistogramBucket, 0, len(keys)) + for _, ts := range keys { + a := grouped[ts] + b := DimensionTokenHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]DimensionTokenStats)} + for dim, da := range a.byDimension { + b.ByDimension[dim] = DimensionTokenStats{ + PromptTokens: da.prompt, + CompletionTokens: da.completion, + TotalTokens: da.total, + } + } + buckets = append(buckets, b) + } + return &DimensionTokenHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil + } + + buckets := make([]DimensionTokenHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionTokenHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]DimensionTokenStats)} + if a, ok := grouped[ts]; ok { + for dim, da := range a.byDimension { + b.ByDimension[dim] = DimensionTokenStats{ + PromptTokens: da.prompt, + CompletionTokens: da.completion, + TotalTokens: da.total, + } + } + } + buckets = append(buckets, b) + } + + return &DimensionTokenHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + +// GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension. +// Uses the mv_logs_hourly materialized view on PostgreSQL when eligible; falls back to raw queries otherwise. +// The fallback path computes AVG latency only (no percentiles) since percentile_cont is Postgres-specific. +func (s *RDBLogStore) GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error) { + if !ValidHistogramDimensions[dimension] { + return nil, fmt.Errorf("invalid histogram dimension: %s", dimension) + } + if bucketSizeSeconds <= 0 { + bucketSizeSeconds = 3600 + } + if s.db.Dialector.Name() == "postgres" && s.canUseMatView(filters) && bucketSizeSeconds >= 3600 { + return s.getDimensionLatencyHistogramFromMatView(ctx, filters, bucketSizeSeconds, dimension) + } + dimCol := string(dimension) + dialect := s.db.Dialector.Name() + baseQuery := s.db.WithContext(ctx).Model(&Log{}) + baseQuery = s.applyFilters(baseQuery, filters) + baseQuery = baseQuery.Where("status IN ?", []string{"success", "error"}) + baseQuery = baseQuery.Where("latency IS NOT NULL") + + var bucketExpr string + switch dialect { + case "sqlite": + bucketExpr = fmt.Sprintf("CAST((CAST(strftime('%%s', timestamp) AS INTEGER) / %d) * %d AS INTEGER)", bucketSizeSeconds, bucketSizeSeconds) + default: + bucketExpr = fmt.Sprintf("CAST(FLOOR(EXTRACT(EPOCH FROM timestamp) / %d) * %d AS BIGINT)", bucketSizeSeconds, bucketSizeSeconds) + } + + var results []struct { + BucketTimestamp int64 `gorm:"column:bucket_timestamp"` + DimValue string `gorm:"column:dim_value"` + AvgLatency float64 `gorm:"column:avg_lat"` + TotalRequests int64 `gorm:"column:total_requests"` + } + if err := baseQuery.Select(fmt.Sprintf(` + %s AS bucket_timestamp, + COALESCE(%s, '') AS dim_value, + COALESCE(AVG(latency), 0) AS avg_lat, + COUNT(*) AS total_requests + `, bucketExpr, dimCol)). + Group(fmt.Sprintf("bucket_timestamp, %s", dimCol)). + Order("bucket_timestamp ASC"). + Find(&results).Error; err != nil { + return nil, err + } + + type bucketAgg struct { + byDimension map[string]DimensionLatencyStats + } + grouped := make(map[int64]*bucketAgg) + dimSet := make(map[string]struct{}) + for _, r := range results { + a, ok := grouped[r.BucketTimestamp] + if !ok { + a = &bucketAgg{byDimension: make(map[string]DimensionLatencyStats)} + grouped[r.BucketTimestamp] = a + } + a.byDimension[r.DimValue] = DimensionLatencyStats{ + AvgLatency: r.AvgLatency, + TotalRequests: r.TotalRequests, + } + dimSet[r.DimValue] = struct{}{} + } + + dimValues := sortedStringKeys(dimSet) + allTimestamps := generateBucketTimestamps(filters.StartTime, filters.EndTime, bucketSizeSeconds) + + // If no time range specified, build buckets directly from query results + if len(allTimestamps) == 0 { + keys := make([]int64, 0, len(grouped)) + for ts := range grouped { + keys = append(keys, ts) + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + buckets := make([]DimensionLatencyHistogramBucket, 0, len(keys)) + for _, ts := range keys { + a := grouped[ts] + buckets = append(buckets, DimensionLatencyHistogramBucket{ + Timestamp: time.Unix(ts, 0).UTC(), + ByDimension: a.byDimension, + }) + } + return &DimensionLatencyHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil + } + + buckets := make([]DimensionLatencyHistogramBucket, 0, len(allTimestamps)) + for _, ts := range allTimestamps { + b := DimensionLatencyHistogramBucket{Timestamp: time.Unix(ts, 0).UTC(), ByDimension: make(map[string]DimensionLatencyStats)} + if a, ok := grouped[ts]; ok { + b.ByDimension = a.byDimension + } + buckets = append(buckets, b) + } + + return &DimensionLatencyHistogramResult{Buckets: buckets, BucketSizeSeconds: bucketSizeSeconds, Dimension: dimension, DimensionValues: dimValues}, nil +} + // HasLogs checks if there are any logs in the database. func (s *RDBLogStore) HasLogs(ctx context.Context) (bool, error) { var log Log @@ -1966,6 +2577,20 @@ func (s *RDBLogStore) FindByID(ctx context.Context, id string) (*Log, error) { return &log, nil } +// IsLogEntryPresent checks if a log entry is present in the database. +// Here we dont load entire log entry in memory - just check if it exists. +func (s *RDBLogStore) IsLogEntryPresent(ctx context.Context, id string) (bool, error) { + var log Log + err := s.db.WithContext(ctx).Select("id").Where("id = ?", id).First(&log).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + return true, nil +} + // FindFirst gets a log entry from the database. func (s *RDBLogStore) FindFirst(ctx context.Context, query any, fields ...string) (*Log, error) { var log Log @@ -1990,7 +2615,7 @@ func (s *RDBLogStore) Flush(ctx context.Context, since time.Time) error { // GetDistinctModels returns all unique non-empty model values using SELECT DISTINCT. // Scoped to recent data to avoid full table scans. func (s *RDBLogStore) GetDistinctModels(ctx context.Context) ([]string, error) { - if s.db.Dialector.Name() == "postgres" { + if s.db.Dialector.Name() == "postgres" && s.matViewsReady.Load() { return s.getDistinctModelsFromMatView(ctx) } cutoff := time.Now().UTC().AddDate(0, 0, -defaultFilterDataCutoffDays) @@ -2004,21 +2629,42 @@ func (s *RDBLogStore) GetDistinctModels(ctx context.Context) ([]string, error) { return models, nil } +// GetDistinctAliases returns all unique non-empty alias values using SELECT DISTINCT. +// Scoped to recent data to avoid full table scans. +func (s *RDBLogStore) GetDistinctAliases(ctx context.Context) ([]string, error) { + cutoff := time.Now().UTC().AddDate(0, 0, -defaultFilterDataCutoffDays) + var aliases []string + err := s.db.WithContext(ctx).Model(&Log{}). + Where("alias IS NOT NULL AND alias != '' AND timestamp >= ?", cutoff). + Distinct("alias").Limit(defaultFilterDataLimit).Pluck("alias", &aliases).Error + if err != nil { + return nil, fmt.Errorf("failed to get distinct aliases: %w", err) + } + return aliases, nil +} + // allowedKeyPairColumns is a whitelist of column names that can be used in GetDistinctKeyPairs // to prevent SQL injection from interpolated column names. var allowedKeyPairColumns = map[string]struct{}{ - "selected_key_id": {}, - "selected_key_name": {}, - "virtual_key_id": {}, - "virtual_key_name": {}, - "routing_rule_id": {}, - "routing_rule_name": {}, + "selected_key_id": {}, + "selected_key_name": {}, + "virtual_key_id": {}, + "virtual_key_name": {}, + "routing_rule_id": {}, + "routing_rule_name": {}, + "team_id": {}, + "team_name": {}, + "customer_id": {}, + "customer_name": {}, + "user_id": {}, + "business_unit_id": {}, + "business_unit_name": {}, } // GetDistinctKeyPairs returns unique non-empty ID-Name pairs for the given columns using SELECT DISTINCT. // idCol and nameCol must be valid column names (e.g., "selected_key_id", "selected_key_name"). func (s *RDBLogStore) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error) { - if s.db.Dialector.Name() == "postgres" { + if s.db.Dialector.Name() == "postgres" && s.matViewsReady.Load() { return s.getDistinctKeyPairsFromMatView(ctx, idCol, nameCol) } if _, ok := allowedKeyPairColumns[idCol]; !ok { @@ -2043,7 +2689,7 @@ func (s *RDBLogStore) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol st // GetDistinctRoutingEngines returns all unique routing engine values from the comma-separated column. // Scoped to recent data to avoid full table scans. func (s *RDBLogStore) GetDistinctRoutingEngines(ctx context.Context) ([]string, error) { - if s.db.Dialector.Name() == "postgres" { + if s.db.Dialector.Name() == "postgres" && s.matViewsReady.Load() { return s.getDistinctRoutingEnginesFromMatView(ctx) } cutoff := time.Now().UTC().AddDate(0, 0, -defaultFilterDataCutoffDays) diff --git a/framework/logstore/rdb_postgres_perf_test.go b/framework/logstore/rdb_postgres_perf_test.go index 404793acb2..c7fab23103 100644 --- a/framework/logstore/rdb_postgres_perf_test.go +++ b/framework/logstore/rdb_postgres_perf_test.go @@ -21,13 +21,15 @@ func setupPerfTestDB(t *testing.T) (*RDBLogStore, *gorm.DB) { t.Skip("Postgres not available, skipping test") } - // Clean slate + // Clean slate β€” drop test-owned tables but preserve the shared migrations + // table so concurrent test packages (e.g. configstore) are not disrupted. db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_hourly CASCADE") db.Exec("DROP MATERIALIZED VIEW IF EXISTS mv_logs_filterdata CASCADE") db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE") db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE") db.Exec("DROP TABLE IF EXISTS logs CASCADE") - db.Exec("DROP TABLE IF EXISTS migrations CASCADE") + db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)") + db.Exec("DELETE FROM migrations") ctx := context.Background() err := triggerMigrations(ctx, db) @@ -47,7 +49,7 @@ func setupPerfTestDB(t *testing.T) (*RDBLogStore, *gorm.DB) { db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE") db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE") db.Exec("DROP TABLE IF EXISTS logs CASCADE") - db.Exec("DROP TABLE IF EXISTS migrations CASCADE") + db.Exec("DELETE FROM migrations") }) return store, db @@ -461,7 +463,8 @@ func TestEnsurePerformanceIndexes(t *testing.T) { db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE") db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE") db.Exec("DROP TABLE IF EXISTS logs CASCADE") - db.Exec("DROP TABLE IF EXISTS migrations CASCADE") + db.Exec("CREATE TABLE IF NOT EXISTS migrations (id VARCHAR(255) PRIMARY KEY)") + db.Exec("DELETE FROM migrations") ctx := context.Background() err := triggerMigrations(ctx, db) @@ -474,7 +477,7 @@ func TestEnsurePerformanceIndexes(t *testing.T) { db.Exec("DROP TABLE IF EXISTS mcp_tool_logs CASCADE") db.Exec("DROP TABLE IF EXISTS async_jobs CASCADE") db.Exec("DROP TABLE IF EXISTS logs CASCADE") - db.Exec("DROP TABLE IF EXISTS migrations CASCADE") + db.Exec("DELETE FROM migrations") }) conn := acquirePerfTestSQLConn(t, ctx, db) diff --git a/framework/logstore/store.go b/framework/logstore/store.go index 6078ab3eea..f209672442 100644 --- a/framework/logstore/store.go +++ b/framework/logstore/store.go @@ -24,11 +24,14 @@ type LogStore interface { CreateIfNotExists(ctx context.Context, entry *Log) error BatchCreateIfNotExists(ctx context.Context, entries []*Log) error FindByID(ctx context.Context, id string) (*Log, error) + IsLogEntryPresent(ctx context.Context, id string) (bool, error) FindFirst(ctx context.Context, query any, fields ...string) (*Log, error) FindAll(ctx context.Context, query any, fields ...string) ([]*Log, error) FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error) HasLogs(ctx context.Context) (bool, error) SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) + GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) + GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error) GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error) GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error) @@ -39,6 +42,13 @@ type LogStore interface { GetProviderTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderTokenHistogramResult, error) GetProviderLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*ProviderLatencyHistogramResult, error) GetModelRankings(ctx context.Context, filters SearchFilters) (*ModelRankingResult, error) + GetUserRankings(ctx context.Context, filters SearchFilters) (*UserRankingResult, error) + // GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension (e.g., team_id, customer_id). + GetDimensionCostHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionCostHistogramResult, error) + // GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension. + GetDimensionTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionTokenHistogramResult, error) + // GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension. + GetDimensionLatencyHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64, dimension HistogramDimension) (*DimensionLatencyHistogramResult, error) Update(ctx context.Context, id string, entry any) error BulkUpdateCost(ctx context.Context, updates map[string]float64) error Flush(ctx context.Context, since time.Time) error @@ -49,6 +59,7 @@ type LogStore interface { // Distinct value methods for filter data GetDistinctModels(ctx context.Context) ([]string, error) + GetDistinctAliases(ctx context.Context) ([]string, error) GetDistinctKeyPairs(ctx context.Context, idCol, nameCol string) ([]KeyPairResult, error) GetDistinctRoutingEngines(ctx context.Context) ([]string, error) GetDistinctMetadataKeys(ctx context.Context) (map[string][]string, error) diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index c8bd05a3c5..20c28bfea4 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -29,22 +29,28 @@ const ( // SearchFilters represents the available filters for log searches type SearchFilters struct { - Providers []string `json:"providers,omitempty"` - Models []string `json:"models,omitempty"` - Status []string `json:"status,omitempty"` - Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) - SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` - VirtualKeyIDs []string `json:"virtual_key_ids,omitempty"` - RoutingRuleIDs []string `json:"routing_rule_ids,omitempty"` - RoutingEngineUsed []string `json:"routing_engine_used,omitempty"` // For filtering by routing engine (routing-rule, governance, loadbalancing) - StartTime *time.Time `json:"start_time,omitempty"` - EndTime *time.Time `json:"end_time,omitempty"` - MinLatency *float64 `json:"min_latency,omitempty"` - MaxLatency *float64 `json:"max_latency,omitempty"` - MinTokens *int `json:"min_tokens,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - MinCost *float64 `json:"min_cost,omitempty"` - MaxCost *float64 `json:"max_cost,omitempty"` + Providers []string `json:"providers,omitempty"` + Models []string `json:"models,omitempty"` + Aliases []string `json:"aliases,omitempty"` + Status []string `json:"status,omitempty"` + Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) + ParentRequestID string `json:"parent_request_id,omitempty"` + SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` + VirtualKeyIDs []string `json:"virtual_key_ids,omitempty"` + RoutingRuleIDs []string `json:"routing_rule_ids,omitempty"` + TeamIDs []string `json:"team_ids,omitempty"` + CustomerIDs []string `json:"customer_ids,omitempty"` + UserIDs []string `json:"user_ids,omitempty"` + BusinessUnitIDs []string `json:"business_unit_ids,omitempty"` + RoutingEngineUsed []string `json:"routing_engine_used,omitempty"` // For filtering by routing engine (routing-rule, governance, loadbalancing) + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + MinLatency *float64 `json:"min_latency,omitempty"` + MaxLatency *float64 `json:"max_latency,omitempty"` + MinTokens *int `json:"min_tokens,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MinCost *float64 `json:"min_cost,omitempty"` + MaxCost *float64 `json:"max_cost,omitempty"` MissingCostOnly bool `json:"missing_cost_only,omitempty"` ContentSearch string `json:"content_search,omitempty"` MetadataFilters map[string]string `json:"metadata_filters,omitempty"` // key=metadataKey, value=metadataValue for filtering by metadata @@ -67,6 +73,25 @@ type SearchResult struct { HasLogs bool `json:"has_logs"` } +type SessionDetailResult struct { + SessionID string `json:"session_id"` + Logs []Log `json:"logs"` + Pagination PaginationOptions `json:"pagination"` + Count int64 `json:"count"` + ReturnedCount int `json:"returned_count"` + HasMore bool `json:"has_more"` +} + +type SessionSummaryResult struct { + SessionID string `json:"session_id"` + Count int64 `json:"count"` + TotalCost float64 `json:"total_cost"` + TotalTokens int64 `json:"total_tokens"` + StartedAt string `json:"started_at,omitempty"` + LatestAt string `json:"latest_at,omitempty"` + DurationMs int64 `json:"duration_ms"` +} + type SearchStats struct { TotalRequests int64 `json:"total_requests"` SuccessRate float64 `json:"success_rate"` // Percentage of successful requests @@ -78,59 +103,70 @@ type SearchStats struct { // Log represents a complete log entry for a request/response cycle // This is the GORM model with appropriate tags type Log struct { - ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` - ParentRequestID *string `gorm:"type:varchar(255)" json:"parent_request_id"` - Timestamp time.Time `gorm:"index;index:idx_logs_ts_provider_status,priority:1;not null" json:"timestamp"` - Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding - Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` - Model string `gorm:"type:varchar(255);index;not null" json:"model"` - NumberOfRetries int `gorm:"default:0" json:"number_of_retries"` - FallbackIndex int `gorm:"default:0" json:"fallback_index"` - SelectedKeyID string `gorm:"type:varchar(255);index:idx_logs_selected_key_id" json:"selected_key_id"` - SelectedKeyName string `gorm:"type:varchar(255)" json:"selected_key_name"` - VirtualKeyID *string `gorm:"type:varchar(255);index:idx_logs_virtual_key_id" json:"virtual_key_id"` - VirtualKeyName *string `gorm:"type:varchar(255)" json:"virtual_key_name"` - RoutingEnginesUsedStr *string `gorm:"type:varchar(255);column:routing_engines_used" json:"-"` // Comma-separated routing engines - RoutingRuleID *string `gorm:"type:varchar(255);index:idx_logs_routing_rule_id" json:"routing_rule_id"` - RoutingRuleName *string `gorm:"type:varchar(255)" json:"routing_rule_name"` - InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ChatMessage - ResponsesInputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ResponsesMessage - OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage - ResponsesOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ResponsesMessage - EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized embedding response data - RerankOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.RerankResult - Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters - Tools string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Tool - ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ToolCall (For backward compatibility, tool calls are now in the content) - SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput - TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput - ImageGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageGenerationInput - VideoGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.VideoGenerationInput - SpeechOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostSpeech - TranscriptionOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostTranscribe - ImageGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostImageGenerationResponse - ListModelsOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Model - VideoGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoGenerationResponse - VideoRetrieveOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoRetrieveResponse - VideoDownloadOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDownloadResponse - VideoListOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoListResponse - VideoDeleteOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDeleteResponse - CacheDebug string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostCacheDebug - Latency *float64 `gorm:"index:idx_logs_latency" json:"latency,omitempty"` - TokenUsage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.LLMUsage - Cost *float64 `gorm:"index" json:"cost,omitempty"` // Cost in dollars (total cost of the request - includes cache lookup cost) - Status string `gorm:"type:varchar(50);index;index:idx_logs_ts_provider_status,priority:3;not null" json:"status"` // "processing", "success", or "error" - ErrorDetails string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostError - Stream bool `gorm:"default:false" json:"stream"` // true if this was a streaming response - ContentSummary string `gorm:"type:text" json:"-"` - RawRequest string `gorm:"type:text" json:"raw_request"` // Populated when `send-back-raw-request` is on - RawResponse string `gorm:"type:text" json:"raw_response"` // Populated when `send-back-raw-response` is on + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + ParentRequestID *string `gorm:"type:varchar(255);index" json:"parent_request_id"` + Timestamp time.Time `gorm:"index;index:idx_logs_ts_provider_status,priority:1;not null" json:"timestamp"` + Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding + Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` + Model string `gorm:"type:varchar(255);index;not null" json:"model"` + Alias *string `gorm:"type:varchar(255);index" json:"alias,omitempty"` // Set when model was resolved via alias mapping; the original name the caller used + NumberOfRetries int `gorm:"default:0" json:"number_of_retries"` + FallbackIndex int `gorm:"default:0" json:"fallback_index"` + SelectedKeyID string `gorm:"type:varchar(255);index:idx_logs_selected_key_id" json:"selected_key_id"` + SelectedKeyName string `gorm:"type:varchar(255)" json:"selected_key_name"` + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_logs_virtual_key_id" json:"virtual_key_id"` + VirtualKeyName *string `gorm:"type:varchar(255)" json:"virtual_key_name"` + RoutingEnginesUsedStr *string `gorm:"type:varchar(255);column:routing_engines_used" json:"-"` // Comma-separated routing engines + RoutingRuleID *string `gorm:"type:varchar(255);index:idx_logs_routing_rule_id" json:"routing_rule_id"` + RoutingRuleName *string `gorm:"type:varchar(255)" json:"routing_rule_name"` + UserID *string `gorm:"type:varchar(255);index:idx_logs_user_id" json:"user_id"` + TeamID *string `gorm:"type:varchar(255);index:idx_logs_team_id" json:"team_id"` + TeamName *string `gorm:"type:varchar(255)" json:"team_name"` + CustomerID *string `gorm:"type:varchar(255);index:idx_logs_customer_id" json:"customer_id"` + CustomerName *string `gorm:"type:varchar(255)" json:"customer_name"` + BusinessUnitID *string `gorm:"type:varchar(255);index:idx_logs_business_unit_id" json:"business_unit_id"` + BusinessUnitName *string `gorm:"type:varchar(255)" json:"business_unit_name"` + InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ChatMessage + ResponsesInputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ResponsesMessage + OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage + ResponsesOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ResponsesMessage + EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized [][]float32 + RerankOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.RerankResult + Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters + Tools string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Tool + ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ToolCall (For backward compatibility, tool calls are now in the content) + SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput + TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput + ImageGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageGenerationInput + ImageEditInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageEditInput + ImageVariationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageVariationInput + VideoGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.VideoGenerationInput + SpeechOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostSpeech + TranscriptionOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostTranscribe + ImageGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostImageGenerationResponse + ListModelsOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Model + VideoGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoGenerationResponse + VideoRetrieveOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoRetrieveResponse + VideoDownloadOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDownloadResponse + VideoListOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoListResponse + VideoDeleteOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDeleteResponse + CacheDebug string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostCacheDebug + Latency *float64 `gorm:"index:idx_logs_latency" json:"latency,omitempty"` + TokenUsage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.LLMUsage + Cost *float64 `gorm:"index" json:"cost,omitempty"` // Cost in dollars (total cost of the request - includes cache lookup cost) + Status string `gorm:"type:varchar(50);index;index:idx_logs_ts_provider_status,priority:3;not null" json:"status"` // "processing", "success", or "error" + ErrorDetails string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostError + Stream bool `gorm:"default:false" json:"stream"` // true if this was a streaming response + ContentSummary string `gorm:"type:text" json:"-"` + RawRequest string `gorm:"type:text" json:"raw_request"` // Populated when `send-back-raw-request` is on + RawResponse string `gorm:"type:text" json:"raw_response"` // Populated when `send-back-raw-response` is on PassthroughRequestBody string `gorm:"type:text" json:"passthrough_request_body,omitempty"` // Raw body for passthrough requests (UTF-8) PassthroughResponseBody string `gorm:"type:text" json:"passthrough_response_body,omitempty"` // Raw body for passthrough responses (UTF-8) - RoutingEngineLogs string `gorm:"type:text" json:"routing_engine_logs,omitempty"` // Formatted routing engine decision logs - Metadata *string `gorm:"type:text" json:"-"` // JSON serialized map[string]interface{} - IsLargePayloadRequest bool `gorm:"default:false" json:"is_large_payload_request"` - IsLargePayloadResponse bool `gorm:"default:false" json:"is_large_payload_response"` + RoutingEngineLogs string `gorm:"type:text" json:"routing_engine_logs,omitempty"` // Formatted routing engine decision logs + PluginLogs string `gorm:"type:text" json:"plugin_logs,omitempty"` // JSON serialized plugin log entries grouped by plugin name + Metadata *string `gorm:"type:text" json:"-"` // JSON serialized map[string]interface{} + IsLargePayloadRequest bool `gorm:"default:false" json:"is_large_payload_request"` + IsLargePayloadResponse bool `gorm:"default:false" json:"is_large_payload_response"` // Denormalized token fields for easier querying PromptTokens int `gorm:"default:0" json:"-"` @@ -156,6 +192,8 @@ type Log struct { SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` TranscriptionInputParsed *schemas.TranscriptionInput `gorm:"-" json:"transcription_input,omitempty"` ImageGenerationInputParsed *schemas.ImageGenerationInput `gorm:"-" json:"image_generation_input,omitempty"` + ImageEditInputParsed *schemas.ImageEditInput `gorm:"-" json:"image_edit_input,omitempty"` + ImageVariationInputParsed *schemas.ImageVariationInput `gorm:"-" json:"image_variation_input,omitempty"` SpeechOutputParsed *schemas.BifrostSpeechResponse `gorm:"-" json:"speech_output,omitempty"` TranscriptionOutputParsed *schemas.BifrostTranscriptionResponse `gorm:"-" json:"transcription_output,omitempty"` ImageGenerationOutputParsed *schemas.BifrostImageGenerationResponse `gorm:"-" json:"image_generation_output,omitempty"` @@ -288,6 +326,22 @@ func (l *Log) SerializeFields() error { } } + if l.ImageEditInputParsed != nil { + if data, err := sonic.Marshal(l.ImageEditInputParsed); err != nil { + return err + } else { + l.ImageEditInput = string(data) + } + } + + if l.ImageVariationInputParsed != nil { + if data, err := sonic.Marshal(l.ImageVariationInputParsed); err != nil { + return err + } else { + l.ImageVariationInput = string(data) + } + } + if l.VideoGenerationInputParsed != nil { if data, err := sonic.Marshal(l.VideoGenerationInputParsed); err != nil { return err @@ -590,6 +644,18 @@ func (l *Log) DeserializeFields() error { } } + if l.ImageEditInput != "" { + if err := sonic.Unmarshal([]byte(l.ImageEditInput), &l.ImageEditInputParsed); err != nil { + l.ImageEditInputParsed = nil + } + } + + if l.ImageVariationInput != "" { + if err := sonic.Unmarshal([]byte(l.ImageVariationInput), &l.ImageVariationInputParsed); err != nil { + l.ImageVariationInputParsed = nil + } + } + if l.SpeechOutput != "" { if err := sonic.Unmarshal([]byte(l.SpeechOutput), &l.SpeechOutputParsed); err != nil { // Log error but don't fail the operation - initialize as nil @@ -1003,6 +1069,11 @@ func (l *Log) BuildContentSummary() string { parts = append(parts, l.ImageGenerationInputParsed.Prompt) } + // Add image edit input prompt + if l.ImageEditInputParsed != nil && l.ImageEditInputParsed.Prompt != "" { + parts = append(parts, l.ImageEditInputParsed.Prompt) + } + // Add video generation input prompt if l.VideoGenerationInputParsed != nil && l.VideoGenerationInputParsed.Prompt != "" { parts = append(parts, l.VideoGenerationInputParsed.Prompt) @@ -1159,6 +1230,87 @@ type ProviderLatencyHistogramResult struct { Providers []string `json:"providers"` } +// HistogramDimension represents a column that can be used as a grouping dimension in histograms +type HistogramDimension string + +const ( + DimensionProvider HistogramDimension = "provider" + DimensionTeam HistogramDimension = "team_id" + DimensionCustomer HistogramDimension = "customer_id" + DimensionUser HistogramDimension = "user_id" + DimensionBusinessUnit HistogramDimension = "business_unit_id" +) + +// ValidHistogramDimensions is the set of allowed dimension values +var ValidHistogramDimensions = map[HistogramDimension]bool{ + DimensionProvider: true, + DimensionTeam: true, + DimensionCustomer: true, + DimensionUser: true, + DimensionBusinessUnit: true, +} + +// Dimension-level histogram types (generic version of Provider histograms) + +// DimensionCostHistogramBucket represents a single time bucket for dimension-grouped cost data +type DimensionCostHistogramBucket struct { + Timestamp time.Time `json:"timestamp"` + TotalCost float64 `json:"total_cost"` + ByDimension map[string]float64 `json:"by_dimension"` +} + +// DimensionCostHistogramResult represents the dimension cost histogram query result +type DimensionCostHistogramResult struct { + Buckets []DimensionCostHistogramBucket `json:"buckets"` + BucketSizeSeconds int64 `json:"bucket_size_seconds"` + Dimension HistogramDimension `json:"dimension"` + DimensionValues []string `json:"dimension_values"` +} + +// DimensionTokenStats represents token statistics for a single dimension value +type DimensionTokenStats struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// DimensionTokenHistogramBucket represents a single time bucket for dimension-grouped token data +type DimensionTokenHistogramBucket struct { + Timestamp time.Time `json:"timestamp"` + ByDimension map[string]DimensionTokenStats `json:"by_dimension"` +} + +// DimensionTokenHistogramResult represents the dimension token histogram query result +type DimensionTokenHistogramResult struct { + Buckets []DimensionTokenHistogramBucket `json:"buckets"` + BucketSizeSeconds int64 `json:"bucket_size_seconds"` + Dimension HistogramDimension `json:"dimension"` + DimensionValues []string `json:"dimension_values"` +} + +// DimensionLatencyStats represents latency statistics for a single dimension value +type DimensionLatencyStats struct { + AvgLatency float64 `json:"avg_latency"` + P90Latency float64 `json:"p90_latency"` + P95Latency float64 `json:"p95_latency"` + P99Latency float64 `json:"p99_latency"` + TotalRequests int64 `json:"total_requests"` +} + +// DimensionLatencyHistogramBucket represents a single time bucket for dimension-grouped latency data +type DimensionLatencyHistogramBucket struct { + Timestamp time.Time `json:"timestamp"` + ByDimension map[string]DimensionLatencyStats `json:"by_dimension"` +} + +// DimensionLatencyHistogramResult represents the dimension latency histogram query result +type DimensionLatencyHistogramResult struct { + Buckets []DimensionLatencyHistogramBucket `json:"buckets"` + BucketSizeSeconds int64 `json:"bucket_size_seconds"` + Dimension HistogramDimension `json:"dimension"` + DimensionValues []string `json:"dimension_values"` +} + // MCPHistogramBucket represents a single time bucket for MCP tool call volume type MCPHistogramBucket struct { Timestamp time.Time `json:"timestamp"` @@ -1170,7 +1322,7 @@ type MCPHistogramBucket struct { // MCPHistogramResult represents the MCP tool call volume histogram query result type MCPHistogramResult struct { Buckets []MCPHistogramBucket `json:"buckets"` - BucketSizeSeconds int64 `json:"bucket_size_seconds"` + BucketSizeSeconds int64 `json:"bucket_size_seconds"` } // MCPCostHistogramBucket represents a single time bucket for MCP cost data @@ -1228,3 +1380,30 @@ type ModelRankingWithTrend struct { type ModelRankingResult struct { Rankings []ModelRankingWithTrend `json:"rankings"` } + +// UserRankingEntry represents a single user's usage statistics. +type UserRankingEntry struct { + UserID string `json:"user_id"` + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` + TotalCost float64 `json:"total_cost"` +} + +// UserRankingTrend represents the percentage change compared to the previous period. +type UserRankingTrend struct { + HasPreviousPeriod bool `json:"has_previous_period"` + RequestsTrend float64 `json:"requests_trend"` + TokensTrend float64 `json:"tokens_trend"` + CostTrend float64 `json:"cost_trend"` +} + +// UserRankingWithTrend combines ranking entry with trend data. +type UserRankingWithTrend struct { + UserRankingEntry + Trend UserRankingTrend `json:"trend"` +} + +// UserRankingResult is the response for the user rankings endpoint. +type UserRankingResult struct { + Rankings []UserRankingWithTrend `json:"rankings"` +} diff --git a/framework/modelcatalog/capabilities_test.go b/framework/modelcatalog/capabilities_test.go index 4d15f75191..3d188f775d 100644 --- a/framework/modelcatalog/capabilities_test.go +++ b/framework/modelcatalog/capabilities_test.go @@ -185,13 +185,17 @@ func TestGetModelCapabilityEntryForModel_PrefersLiteralMatchOverAliasFamily(t *t func TestCapabilityFieldsRoundTripThroughPricingConversions(t *testing.T) { modality := "text" + inputCost := float64(1) + outputCost := float64(2) entry := PricingEntry{ - BaseModel: "gpt-4o", - Provider: "openai", - Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, - ContextLength: capabilityIntPtr(128000), + BaseModel: "gpt-4o", + Provider: "openai", + Mode: "chat", + PricingOptions: PricingOptions{ + InputCostPerToken: &inputCost, + OutputCostPerToken: &outputCost, + }, + ContextLength: capabilityIntPtr(128000), MaxInputTokens: capabilityIntPtr(64000), MaxOutputTokens: capabilityIntPtr(16000), Architecture: &schemas.Architecture{ diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 99e72242aa..2c03381927 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -41,10 +41,13 @@ type ModelCatalog struct { pricingData map[string]configstoreTables.TableModelPricing mu sync.RWMutex - // Provider-level pricing overrides are maintained separately to avoid contention - // with pricing cache rebuilds. - compiledOverrides map[schemas.ModelProvider][]compiledProviderPricingOverride - overridesMu sync.RWMutex + // rawOverrides is the canonical list of all active overrides. It exists solely + // to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride + // iterate over it to rebuild the list, then derive customPricing from it. + // customPricing is the actual lookup structure used at query time. + rawOverrides []PricingOverride + customPricing *customPricingData + overridesMu sync.RWMutex modelPool map[schemas.ModelProvider][]string unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering @@ -70,9 +73,13 @@ type PricingEntry struct { MaxOutputTokens *int `json:"max_output_tokens,omitempty"` Architecture *schemas.Architecture `json:"architecture,omitempty"` + PricingOptions +} + +type PricingOptions struct { // Costs - Text - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` + InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"` + OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"` InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority,omitempty"` @@ -202,7 +209,6 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto configStore: configStore, logger: logger, pricingData: make(map[string]configstoreTables.TableModelPricing), - compiledOverrides: make(map[schemas.ModelProvider][]compiledProviderPricingOverride), modelPool: make(map[schemas.ModelProvider][]string), unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: make(map[string]string), @@ -276,6 +282,10 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto // Populate model pool with normalized providers from pricing data mc.populateModelPoolFromPricingData() + if err := mc.loadPricingOverridesFromStore(ctx); err != nil { + return nil, fmt.Errorf("failed to load pricing overrides: %w", err) + } + // Start background sync worker mc.syncCtx, mc.syncCancel = context.WithCancel(ctx) mc.startSyncWorker(mc.syncCtx) @@ -346,6 +356,10 @@ func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error { // Rebuild model pool from updated pricing data mc.populateModelPoolFromPricingData() + if err := mc.loadPricingOverridesFromStore(ctx); err != nil { + return fmt.Errorf("failed to load pricing overrides: %w", err) + } + // Also sync model parameters if err := mc.syncModelParameters(ctx); err != nil { mc.logger.Warn("failed to sync model parameters during force reload: %v", err) @@ -622,8 +636,9 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid // - allowedModels: List of allowed model names (can be empty, can include provider prefixes) // // Behavior: -// - If allowedModels is empty: Uses model catalog to check if provider supports the model +// - If allowedModels is ["*"]: Uses model catalog to check if provider supports the model // (delegates to GetProvidersForModel which handles all cross-provider logic) +// - If allowedModels is empty ([]): Deny-by-default β€” returns false for any provider/model pair // - If allowedModels is not empty: Checks if model matches any entry in the list // Provider-specific validation: // - Direct matches: "gpt-4o" in allowedModels for any provider @@ -636,10 +651,14 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid // // Examples: // -// // Empty allowedModels - uses catalog -// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{}) +// // Wildcard allowedModels - uses catalog to check provider support +// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{"*"}) // // Returns: true (catalog knows openrouter has "anthropic/claude-3-5-sonnet") // +// // Empty allowedModels - deny all (deny-by-default) +// mc.IsModelAllowedForProvider("openrouter", "claude-3-5-sonnet", []string{}) +// // Returns: false (no models are permitted) +// // // Explicit allowedModels with prefix - validates against catalog // mc.IsModelAllowedForProvider("openrouter", "gpt-4o", []string{"openai/gpt-4o"}) // // Returns: true (openrouter's catalog contains "openai/gpt-4o" AND model part is "gpt-4o") @@ -651,13 +670,16 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid // // Explicit allowedModels without prefix // mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) // // Returns: true (direct match) -func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels []string) bool { - // Case 1: Empty allowedModels = use catalog to determine support - // This leverages GetProvidersForModel which already handles all cross-provider logic - if len(allowedModels) == 0 { +func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels schemas.WhiteList) bool { + // Case 1: ["*"] = allow all models; use catalog to determine support + // Empty allowedModels = deny all (fail-safe deny-by-default) + if allowedModels.IsUnrestricted() { supportedProviders := mc.GetProvidersForModel(model) return slices.Contains(supportedProviders, provider) } + if allowedModels.IsEmpty() { + return false + } // Case 2: Explicit allowedModels = check if model matches any entry // Get provider's catalog models for validation of prefixed entries @@ -752,7 +774,7 @@ func (mc *ModelCatalog) DeleteModelDataForProvider(provider schemas.ModelProvide } // UpsertModelDataForProvider upserts model data for a given provider -func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model, deniedModels []schemas.Model) { +func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvider, modelData *schemas.BifrostListModelsResponse, allowedModels []schemas.Model) { if modelData == nil { return } @@ -781,7 +803,7 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide } } // If modelData is empty, then we allow all models - if len(modelData.Data) == 0 && len(allowedModels) == 0 && len(deniedModels) == 0 { + if len(modelData.Data) == 0 && len(allowedModels) == 0 { mc.modelPool[provider] = providerModels return } @@ -814,15 +836,7 @@ func (mc *ModelCatalog) UpsertModelDataForProvider(provider schemas.ModelProvide } if len(allowedModels) == 0 { - deniedSet := make(map[string]struct{}, len(deniedModels)) - for _, d := range deniedModels { - _, modelName := schemas.ParseModelString(d.ID, "") - deniedSet[modelName] = struct{}{} - } for _, model := range providerModels { - if _, denied := deniedSet[model]; denied { - continue - } if !seenModels[model] { seenModels[model] = true finalModelList = append(finalModelList, model) @@ -925,6 +939,79 @@ func (mc *ModelCatalog) refineNestedProviderModel(provider schemas.ModelProvider } } +// SetPricingOverrides replaces the full in-memory pricing override set. +func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error { + seen := make(map[string]int, len(rows)) + overrides := make([]PricingOverride, 0, len(rows)) + for i := range rows { + o, err := convertTablePricingOverrideToPricingOverride(&rows[i]) + if err != nil { + return err + } + if idx, exists := seen[o.ID]; exists { + overrides[idx] = o // last entry wins for duplicate IDs + } else { + seen[o.ID] = len(overrides) + overrides = append(overrides, o) + } + } + mc.overridesMu.Lock() + mc.rawOverrides = overrides + mc.customPricing = buildCustomPricingData(overrides) + mc.overridesMu.Unlock() + return nil +} + +// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single +// operation, rebuilding the lookup map only once at the end. +func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error { + // Deduplicate the input batch by ID (last entry wins) and build the + // incoming set for O(1) lookup when filtering existing rawOverrides. + seenIncoming := make(map[string]int, len(rows)) + overrides := make([]PricingOverride, 0, len(rows)) + for _, row := range rows { + o, err := convertTablePricingOverrideToPricingOverride(row) + if err != nil { + return err + } + if idx, exists := seenIncoming[o.ID]; exists { + overrides[idx] = o // last entry wins for duplicate IDs + } else { + seenIncoming[o.ID] = len(overrides) + overrides = append(overrides, o) + } + } + + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides)) + for _, o := range mc.rawOverrides { + if _, replacing := seenIncoming[o.ID]; !replacing { + updated = append(updated, o) + } + } + updated = append(updated, overrides...) + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) + return nil +} + +// DeletePricingOverride removes a pricing override by ID. +func (mc *ModelCatalog) DeletePricingOverride(id string) { + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)) + for _, o := range mc.rawOverrides { + if o.ID != id { + updated = append(updated, o) + } + } + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) +} + // IsTextCompletionSupported checks if a model supports text completion for the given provider. // Returns true if the model has pricing data for text completion ("text_completion"), // false otherwise. This is used by the litellmcompat plugin to determine whether to @@ -1019,7 +1106,6 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), - compiledOverrides: make(map[schemas.ModelProvider][]compiledProviderPricingOverride), done: make(chan struct{}), } } diff --git a/framework/modelcatalog/main_test.go b/framework/modelcatalog/main_test.go index 3b7e67e702..8d313493cd 100644 --- a/framework/modelcatalog/main_test.go +++ b/framework/modelcatalog/main_test.go @@ -17,10 +17,9 @@ func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex baseModelIndex = make(map[string]string) } return &ModelCatalog{ - modelPool: modelPool, - baseModelIndex: baseModelIndex, - pricingData: make(map[string]configstoreTables.TableModelPricing), - compiledOverrides: make(map[schemas.ModelProvider][]compiledProviderPricingOverride), + modelPool: modelPool, + baseModelIndex: baseModelIndex, + pricingData: make(map[string]configstoreTables.TableModelPricing), } } diff --git a/framework/modelcatalog/overrides.go b/framework/modelcatalog/overrides.go index 6eef025a48..f284a80a8e 100644 --- a/framework/modelcatalog/overrides.go +++ b/framework/modelcatalog/overrides.go @@ -1,279 +1,456 @@ package modelcatalog import ( + "context" "fmt" - "regexp" + "sort" "strings" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) -type compiledProviderPricingOverride struct { - override schemas.ProviderPricingOverride - regex *regexp.Regexp - requestModes map[string]struct{} - hasRequestFilter bool - literalChars int - order int +// PricingLookupScopes carries the runtime identifiers used to resolve scoped +// pricing overrides during cost calculation. +type PricingLookupScopes struct { + VirtualKeyID string + SelectedKeyID string + Provider string } -func (mc *ModelCatalog) SetProviderPricingOverrides(provider schemas.ModelProvider, overrides []schemas.ProviderPricingOverride) error { - compiled := make([]compiledProviderPricingOverride, 0, len(overrides)) - for i := range overrides { - item, err := compileProviderPricingOverride(i, overrides[i]) - if err != nil { - return fmt.Errorf("invalid pricing override for provider %s at index %d: %w", provider, i, err) - } - compiled = append(compiled, item) - } - - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - if len(compiled) == 0 { - delete(mc.compiledOverrides, provider) +// PricingLookupScopesFromContext builds a PricingLookupScopes from a BifrostContext. +// It reads the governance virtual key ID (not the raw VK token) and the selected key ID. +// provider should be the provider name string (e.g. "openai"), pass "" if unavailable. +// Returns nil only when ctx is nil. An empty scopes value is still returned when all fields +// are empty so that global-scope overrides are always evaluated. +// DO NOT USE THIS FUNCTION IN A GO ROUTINE. This is because it reads from ctx which is cancelled when the request ends. +// Better to call it in PostHooks synchronously and then pass the scopes object to the pricing manager. +// Only use this in go routines when you know for sure that the request will not end before the go routine completes. +func PricingLookupScopesFromContext(ctx *schemas.BifrostContext, provider string) *PricingLookupScopes { + if ctx == nil { return nil } - mc.compiledOverrides[provider] = compiled - return nil + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + selectedKeyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string) + return &PricingLookupScopes{ + VirtualKeyID: virtualKeyID, + SelectedKeyID: selectedKeyID, + Provider: provider, + } } -func (mc *ModelCatalog) DeleteProviderPricingOverrides(provider schemas.ModelProvider) { - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - delete(mc.compiledOverrides, provider) -} +// ScopeKind identifies which governance scope an override applies to. +type ScopeKind string -func (mc *ModelCatalog) applyPricingOverrides(provider schemas.ModelProvider, model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing) configstoreTables.TableModelPricing { - mc.overridesMu.RLock() - overrides := mc.compiledOverrides[provider] - mc.overridesMu.RUnlock() - if len(overrides) == 0 { - return pricing - } +const ( + ScopeKindGlobal ScopeKind = "global" + ScopeKindProvider ScopeKind = "provider" + ScopeKindProviderKey ScopeKind = "provider_key" + ScopeKindVirtualKey ScopeKind = "virtual_key" + ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider" + ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key" +) - modelCandidates := []string{model} - mode := normalizeRequestType(requestType) - best := selectBestOverride(overrides, modelCandidates, mode) - if best == nil { - return pricing - } +// MatchType controls how an override pattern is matched against model names. +type MatchType string - return patchPricing(pricing, best.override) +const ( + MatchTypeExact MatchType = "exact" + MatchTypeWildcard MatchType = "wildcard" +) + +// PricingOverride describes a scoped pricing override shared across config storage, +// model catalog compilation, and governance APIs. +type PricingOverride struct { + ID string `json:"id"` + Name string `json:"name"` + ScopeKind ScopeKind `json:"scope_kind"` + VirtualKeyID *string `json:"virtual_key_id,omitempty"` + ProviderID *string `json:"provider_id,omitempty"` + ProviderKeyID *string `json:"provider_key_id,omitempty"` + MatchType MatchType `json:"match_type"` + Pattern string `json:"pattern"` + RequestTypes []schemas.RequestType `json:"request_types,omitempty"` + Options PricingOptions `json:"options"` } -func compileProviderPricingOverride(order int, override schemas.ProviderPricingOverride) (compiledProviderPricingOverride, error) { - pattern := strings.TrimSpace(override.ModelPattern) - if pattern == "" { - return compiledProviderPricingOverride{}, fmt.Errorf("model_pattern cannot be empty") - } +// customPricingEntry is a single flattened override ready for lookup. +type customPricingEntry struct { + id string + scopeKind ScopeKind + virtualKeyID string + providerID string + providerKeyID string + pattern string // exact model name, or wildcard prefix (trailing * stripped) + wildcard bool + requestModes map[string]struct{} // always non-nil for valid overrides + options PricingOptions +} - result := compiledProviderPricingOverride{ - override: override, - requestModes: make(map[string]struct{}), - order: order, - } - result.override.ModelPattern = pattern +// customPricingData is the in-memory lookup structure for pricing overrides. +// Exact matches are indexed by model name; wildcards are a flat slice. +type customPricingData struct { + exact map[string][]customPricingEntry + wildcard []customPricingEntry +} - switch override.MatchType { - case schemas.PricingOverrideMatchExact: - result.literalChars = len(pattern) - case schemas.PricingOverrideMatchWildcard: - if !strings.Contains(pattern, "*") { - return compiledProviderPricingOverride{}, fmt.Errorf("wildcard model_pattern must contain '*'") +// IsValid validates the shared pricing override contract before persistence or runtime use. +// +// Input: override β€” the PricingOverride to validate (receiver). +// Output: error β€” non-nil if any scope, pattern, or request-type constraint is violated. +func (override *PricingOverride) IsValid() error { + if err := override.validateScopeKind(); err != nil { + return err + } + if err := override.validatePattern(); err != nil { + return err + } + return override.validateRequestTypes() +} + +// validateScopeKind validates the scope identifiers required by override.ScopeKind. +// +// Input: override β€” receiver; ScopeKind and the three optional ID fields are inspected. +// Output: error β€” non-nil when required identifiers are absent or forbidden ones are present. +func (override *PricingOverride) validateScopeKind() error { + switch override.ScopeKind { + case ScopeKindGlobal: + if override.VirtualKeyID != nil || override.ProviderID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("global scope_kind must not include scope identifiers") } - result.literalChars = len(strings.ReplaceAll(pattern, "*", "")) - case schemas.PricingOverrideMatchRegex: - re, err := regexp.Compile(pattern) - if err != nil { - return compiledProviderPricingOverride{}, fmt.Errorf("invalid regex model_pattern: %w", err) + case ScopeKindProvider: + if override.ProviderID == nil { + return fmt.Errorf("provider_id is required for provider scope_kind") } - result.regex = re - result.literalChars = len(pattern) - default: - return compiledProviderPricingOverride{}, fmt.Errorf("unsupported match_type: %s", override.MatchType) - } - - if len(override.RequestTypes) > 0 { - result.hasRequestFilter = true - for _, requestType := range override.RequestTypes { - mode := normalizeRequestType(requestType) - if mode == "unknown" { - return compiledProviderPricingOverride{}, fmt.Errorf("unsupported request_type: %s", requestType) - } - result.requestModes[mode] = struct{}{} + if override.VirtualKeyID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("provider scope_kind only supports provider_id") + } + case ScopeKindProviderKey: + if override.ProviderKeyID == nil { + return fmt.Errorf("provider_key_id is required for provider_key scope_kind") + } + if override.VirtualKeyID != nil || override.ProviderID != nil { + return fmt.Errorf("provider_key scope_kind only supports provider_key_id") + } + case ScopeKindVirtualKey: + if override.VirtualKeyID == nil { + return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind") + } + if override.ProviderID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id") + } + case ScopeKindVirtualKeyProvider: + if override.VirtualKeyID == nil || override.ProviderID == nil { + return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind") } + if override.ProviderKeyID != nil { + return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id") + } + case ScopeKindVirtualKeyProviderKey: + if override.VirtualKeyID == nil || override.ProviderID == nil || override.ProviderKeyID == nil { + return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind") + } + default: + return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind) } - - return result, nil + return nil } -func selectBestOverride(overrides []compiledProviderPricingOverride, modelCandidates []string, mode string) *compiledProviderPricingOverride { - var best *compiledProviderPricingOverride - for i := range overrides { - candidate := &overrides[i] - if candidate.hasRequestFilter { - if _, ok := candidate.requestModes[mode]; !ok { - continue - } +// validatePattern checks that Pattern is non-empty and consistent with MatchType. +// +// Input: override β€” receiver; Pattern and MatchType are inspected. +// Output: error β€” non-nil when the pattern is empty, contains a wildcard for exact mode, +// +// or does not end with a single trailing "*" for wildcard mode. +func (override *PricingOverride) validatePattern() error { + pattern := strings.TrimSpace(override.Pattern) + if pattern == "" { + return fmt.Errorf("pattern is required") + } + switch override.MatchType { + case MatchTypeExact: + if strings.Contains(pattern, "*") { + return fmt.Errorf("exact match pattern must not contain wildcards") } - if !matchesAnyModel(candidate, modelCandidates) { - continue + case MatchTypeWildcard: + if !strings.HasSuffix(pattern, "*") { + return fmt.Errorf("wildcard pattern must end with *") } - if isBetterOverride(candidate, best) { - best = candidate + if strings.Count(pattern, "*") != 1 { + return fmt.Errorf("wildcard pattern must contain exactly one trailing *") } + default: + return fmt.Errorf("unsupported match_type %q", override.MatchType) } - return best + return nil } -func matchesAnyModel(override *compiledProviderPricingOverride, modelCandidates []string) bool { - for _, model := range modelCandidates { - if matchesModel(override, model) { - return true +// validateRequestTypes checks that RequestTypes is non-empty and that every entry is a +// supported base request type. Stream variants (e.g. chat_completion_stream) are rejected β€” +// the base type (chat_completion) already covers both streaming and non-streaming requests. +// +// Input: override β€” receiver; RequestTypes slice is inspected. +// Output: error β€” non-nil if RequestTypes is empty, or contains an unsupported or stream variant. +func (override *PricingOverride) validateRequestTypes() error { + if len(override.RequestTypes) == 0 { + return fmt.Errorf("request_types is required and must contain at least one value") + } + for _, rt := range override.RequestTypes { + if normalizeStreamRequestType(rt) != rt { + return fmt.Errorf("unsupported request_type %q: use the base type (e.g. %q covers both streaming and non-streaming)", rt, normalizeStreamRequestType(rt)) + } + if normalizeRequestType(rt) == "unknown" { + return fmt.Errorf("unsupported request_type %q", rt) } } - return false + return nil } -func matchesModel(override *compiledProviderPricingOverride, model string) bool { - switch override.override.MatchType { - case schemas.PricingOverrideMatchExact: - return model == override.override.ModelPattern - case schemas.PricingOverrideMatchWildcard: - return wildcardMatch(override.override.ModelPattern, model) - case schemas.PricingOverrideMatchRegex: - return override.regex != nil && override.regex.MatchString(model) - default: - return false +// matchesScope reports whether the entry's governance scope matches the runtime identifiers. +// +// Input: scopes β€” runtime VirtualKeyID, SelectedKeyID, and Provider to match against. +// Output: bool β€” true when the entry's scope kind and stored IDs align with scopes. +func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool { + switch e.scopeKind { + case ScopeKindGlobal: + return true + case ScopeKindProvider: + return e.providerID == scopes.Provider + case ScopeKindProviderKey: + return e.providerKeyID == scopes.SelectedKeyID + case ScopeKindVirtualKey: + return e.virtualKeyID == scopes.VirtualKeyID + case ScopeKindVirtualKeyProvider: + return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider + case ScopeKindVirtualKeyProviderKey: + return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider && e.providerKeyID == scopes.SelectedKeyID } + return false } -func overridePriority(matchType schemas.PricingOverrideMatchType) int { - switch matchType { - case schemas.PricingOverrideMatchExact: - return 0 - case schemas.PricingOverrideMatchWildcard: - return 1 - case schemas.PricingOverrideMatchRegex: - return 2 - default: - return 3 - } +// matchesMode reports whether the entry applies to the given normalized request mode. +// +// Input: mode β€” normalized request type string (e.g. "chat", "embedding"). +// Output: bool β€” true when requestModes contains mode. +func (e *customPricingEntry) matchesMode(mode string) bool { + _, ok := e.requestModes[mode] + return ok } -func isBetterOverride(candidate, best *compiledProviderPricingOverride) bool { - if best == nil { - return true - } - - candidatePriority := overridePriority(candidate.override.MatchType) - bestPriority := overridePriority(best.override.MatchType) - if candidatePriority != bestPriority { - return candidatePriority < bestPriority +// resolve walks the 6-scope priority hierarchy and returns the first matching +// pricing patch for the given model, request mode, and runtime scopes. +// +// Input: model β€” exact model name being priced. +// +// mode β€” normalized request type string (e.g. "chat", "embedding"). +// scopes β€” runtime governance identifiers used to narrow the scope search. +// +// Output: *PricingOptions β€” pointer to the first matching override's options, or nil if none match. +func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions { + for _, scopeKind := range scopePriorityOrder(scopes) { + for i := range c.exact[model] { + e := &c.exact[model][i] + if e.scopeKind == scopeKind && e.matchesScope(scopes) && e.matchesMode(mode) { + return &e.options + } + } + for i := range c.wildcard { + e := &c.wildcard[i] + if e.scopeKind == scopeKind && e.matchesScope(scopes) && strings.HasPrefix(model, e.pattern) && e.matchesMode(mode) { + return &e.options + } + } } + return nil +} - if candidate.hasRequestFilter != best.hasRequestFilter { - return candidate.hasRequestFilter - } +// scopePriorityOrder returns scope kinds in most-specific-first order, +// skipping scopes that can't match given the available runtime identifiers. +// +// Input: scopes β€” runtime governance identifiers; empty fields cause the corresponding scope kinds to be omitted. +// Output: []ScopeKind β€” ordered list from most-specific (VirtualKeyProviderKey) to least-specific (Global). +func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind { + order := make([]ScopeKind, 0, 6) + if scopes.VirtualKeyID != "" && scopes.Provider != "" && scopes.SelectedKeyID != "" { + order = append(order, ScopeKindVirtualKeyProviderKey) + } + if scopes.VirtualKeyID != "" && scopes.Provider != "" { + order = append(order, ScopeKindVirtualKeyProvider) + } + if scopes.VirtualKeyID != "" { + order = append(order, ScopeKindVirtualKey) + } + if scopes.SelectedKeyID != "" { + order = append(order, ScopeKindProviderKey) + } + if scopes.Provider != "" { + order = append(order, ScopeKindProvider) + } + order = append(order, ScopeKindGlobal) + return order +} - if candidate.literalChars != best.literalChars { - return candidate.literalChars > best.literalChars +// buildCustomPricingData constructs a customPricingData lookup structure from a raw override slice. +// +// Input: overrides β€” slice of validated PricingOverride records loaded from the config store. +// Output: *customPricingData β€” ready-to-query structure with exact and wildcard indexes populated. +func buildCustomPricingData(overrides []PricingOverride) *customPricingData { + data := &customPricingData{ + exact: make(map[string][]customPricingEntry, len(overrides)), + } + for _, o := range overrides { + entry := customPricingEntry{ + id: o.ID, + scopeKind: o.ScopeKind, + options: o.Options, + } + if o.VirtualKeyID != nil { + entry.virtualKeyID = *o.VirtualKeyID + } + if o.ProviderID != nil { + entry.providerID = *o.ProviderID + } + if o.ProviderKeyID != nil { + entry.providerKeyID = *o.ProviderKeyID + } + entry.requestModes = make(map[string]struct{}, len(o.RequestTypes)) + for _, rt := range o.RequestTypes { + entry.requestModes[normalizeRequestType(rt)] = struct{}{} + } + pattern := strings.TrimSpace(o.Pattern) + switch o.MatchType { + case MatchTypeExact: + entry.pattern = pattern + data.exact[pattern] = append(data.exact[pattern], entry) + case MatchTypeWildcard: + entry.pattern = strings.TrimSuffix(pattern, "*") + entry.wildcard = true + data.wildcard = append(data.wildcard, entry) + } } - - return candidate.order < best.order + // Sort wildcards by descending prefix length so more-specific patterns (e.g. "gpt-4*") + // are checked before broader ones (e.g. "gpt-*"), making precedence deterministic. + sort.Slice(data.wildcard, func(i, j int) bool { + return len(data.wildcard[i].pattern) > len(data.wildcard[j].pattern) + }) + return data } -func wildcardMatch(pattern, model string) bool { - parts := strings.Split(pattern, "*") - if len(parts) == 1 { - return model == pattern - } +// applyPricingOverrides resolves any active scoped pricing override for the given model +// and request type, then patches the catalog base pricing with the override values. +// It returns the original pricing unchanged when no custom pricing tree is loaded or +// when the request type cannot be mapped to a known pricing mode. +// +// Input: model β€” exact model name being priced. +// +// requestType β€” the request type used to derive the pricing mode. +// pricing β€” base pricing row from the catalog to patch. +// scopes β€” runtime governance identifiers used to narrow the override scope. +// +// Output: TableModelPricing β€” patched pricing row, or pricing unchanged if no override matches. +// bool β€” true when an override was applied, false otherwise. +func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) { + mc.overridesMu.RLock() + custom := mc.customPricing + mc.overridesMu.RUnlock() - remaining := model - if parts[0] != "" { - if !strings.HasPrefix(remaining, parts[0]) { - return false - } - remaining = remaining[len(parts[0]):] + if custom == nil { + return pricing, false } - for i := 1; i < len(parts)-1; i++ { - part := parts[i] - if part == "" { - continue - } - index := strings.Index(remaining, part) - if index < 0 { - return false - } - remaining = remaining[index+len(part):] + mode := normalizeRequestType(requestType) + if mode == "unknown" { + return pricing, false } - last := parts[len(parts)-1] - if last == "" { - return true + if patch := custom.resolve(model, mode, scopes); patch != nil { + return patchPricing(pricing, *patch), true } - return strings.HasSuffix(remaining, last) + return pricing, false } -func patchPricing(pricing configstoreTables.TableModelPricing, override schemas.ProviderPricingOverride) configstoreTables.TableModelPricing { +// patchPricing applies override values onto a copy of the base pricing row. +// For all fields, a non-nil override pointer replaces the corresponding destination value; +// a nil override leaves the base value intact. +// The original pricing row is never modified; a patched copy is always returned. +// +// Input: pricing β€” base pricing row from the catalog. +// +// override β€” pricing options sourced from the matched override entry. +// +// Output: TableModelPricing β€” shallow copy of pricing with override fields applied. +func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing { patched := pricing - if override.InputCostPerToken != nil { - patched.InputCostPerToken = *override.InputCostPerToken - } - if override.OutputCostPerToken != nil { - patched.OutputCostPerToken = *override.OutputCostPerToken - } - if override.InputCostPerVideoPerSecond != nil { - patched.InputCostPerVideoPerSecond = override.InputCostPerVideoPerSecond - } - if override.InputCostPerAudioPerSecond != nil { - patched.InputCostPerAudioPerSecond = override.InputCostPerAudioPerSecond - } - if override.InputCostPerTokenAbove200kTokens != nil { - patched.InputCostPerTokenAbove200kTokens = override.InputCostPerTokenAbove200kTokens - } - if override.OutputCostPerTokenAbove200kTokens != nil { - patched.OutputCostPerTokenAbove200kTokens = override.OutputCostPerTokenAbove200kTokens - } - if override.CacheCreationInputTokenCostAbove200kTokens != nil { - patched.CacheCreationInputTokenCostAbove200kTokens = override.CacheCreationInputTokenCostAbove200kTokens - } - if override.CacheReadInputTokenCostAbove200kTokens != nil { - patched.CacheReadInputTokenCostAbove200kTokens = override.CacheReadInputTokenCostAbove200kTokens - } - if override.CacheReadInputTokenCost != nil { - patched.CacheReadInputTokenCost = override.CacheReadInputTokenCost - } - if override.CacheCreationInputTokenCost != nil { - patched.CacheCreationInputTokenCost = override.CacheCreationInputTokenCost - } - if override.InputCostPerTokenBatches != nil { - patched.InputCostPerTokenBatches = override.InputCostPerTokenBatches - } - if override.OutputCostPerTokenBatches != nil { - patched.OutputCostPerTokenBatches = override.OutputCostPerTokenBatches - } - if override.InputCostPerImage != nil { - patched.InputCostPerImage = override.InputCostPerImage - } - if override.OutputCostPerImage != nil { - patched.OutputCostPerImage = override.OutputCostPerImage - } - if override.OutputCostPerImageLowQuality != nil { - patched.OutputCostPerImageLowQuality = override.OutputCostPerImageLowQuality - } - if override.OutputCostPerImageMediumQuality != nil { - patched.OutputCostPerImageMediumQuality = override.OutputCostPerImageMediumQuality + for _, field := range []struct { + dst **float64 + src *float64 + }{ + {dst: &patched.InputCostPerToken, src: override.InputCostPerToken}, + {dst: &patched.OutputCostPerToken, src: override.OutputCostPerToken}, + {dst: &patched.InputCostPerTokenPriority, src: override.InputCostPerTokenPriority}, + {dst: &patched.OutputCostPerTokenPriority, src: override.OutputCostPerTokenPriority}, + {dst: &patched.InputCostPerVideoPerSecond, src: override.InputCostPerVideoPerSecond}, + {dst: &patched.OutputCostPerVideoPerSecond, src: override.OutputCostPerVideoPerSecond}, + {dst: &patched.OutputCostPerSecond, src: override.OutputCostPerSecond}, + {dst: &patched.InputCostPerAudioPerSecond, src: override.InputCostPerAudioPerSecond}, + {dst: &patched.InputCostPerSecond, src: override.InputCostPerSecond}, + {dst: &patched.InputCostPerAudioToken, src: override.InputCostPerAudioToken}, + {dst: &patched.OutputCostPerAudioToken, src: override.OutputCostPerAudioToken}, + {dst: &patched.InputCostPerCharacter, src: override.InputCostPerCharacter}, + {dst: &patched.InputCostPerTokenAbove128kTokens, src: override.InputCostPerTokenAbove128kTokens}, + {dst: &patched.InputCostPerImageAbove128kTokens, src: override.InputCostPerImageAbove128kTokens}, + {dst: &patched.InputCostPerVideoPerSecondAbove128kTokens, src: override.InputCostPerVideoPerSecondAbove128kTokens}, + {dst: &patched.InputCostPerAudioPerSecondAbove128kTokens, src: override.InputCostPerAudioPerSecondAbove128kTokens}, + {dst: &patched.OutputCostPerTokenAbove128kTokens, src: override.OutputCostPerTokenAbove128kTokens}, + {dst: &patched.InputCostPerTokenAbove200kTokens, src: override.InputCostPerTokenAbove200kTokens}, + {dst: &patched.OutputCostPerTokenAbove200kTokens, src: override.OutputCostPerTokenAbove200kTokens}, + {dst: &patched.CacheCreationInputTokenCostAbove200kTokens, src: override.CacheCreationInputTokenCostAbove200kTokens}, + {dst: &patched.CacheReadInputTokenCostAbove200kTokens, src: override.CacheReadInputTokenCostAbove200kTokens}, + {dst: &patched.CacheReadInputTokenCost, src: override.CacheReadInputTokenCost}, + {dst: &patched.CacheCreationInputTokenCost, src: override.CacheCreationInputTokenCost}, + {dst: &patched.CacheCreationInputTokenCostAbove1hr, src: override.CacheCreationInputTokenCostAbove1hr}, + {dst: &patched.CacheCreationInputTokenCostAbove1hrAbove200kTokens, src: override.CacheCreationInputTokenCostAbove1hrAbove200kTokens}, + {dst: &patched.CacheCreationInputAudioTokenCost, src: override.CacheCreationInputAudioTokenCost}, + {dst: &patched.CacheReadInputTokenCostPriority, src: override.CacheReadInputTokenCostPriority}, + {dst: &patched.InputCostPerTokenBatches, src: override.InputCostPerTokenBatches}, + {dst: &patched.OutputCostPerTokenBatches, src: override.OutputCostPerTokenBatches}, + {dst: &patched.InputCostPerImageToken, src: override.InputCostPerImageToken}, + {dst: &patched.OutputCostPerImageToken, src: override.OutputCostPerImageToken}, + {dst: &patched.InputCostPerImage, src: override.InputCostPerImage}, + {dst: &patched.OutputCostPerImage, src: override.OutputCostPerImage}, + {dst: &patched.InputCostPerPixel, src: override.InputCostPerPixel}, + {dst: &patched.OutputCostPerPixel, src: override.OutputCostPerPixel}, + {dst: &patched.OutputCostPerImagePremiumImage, src: override.OutputCostPerImagePremiumImage}, + {dst: &patched.OutputCostPerImageAbove512x512Pixels, src: override.OutputCostPerImageAbove512x512Pixels}, + {dst: &patched.OutputCostPerImageAbove512x512PixelsPremium, src: override.OutputCostPerImageAbove512x512PixelsPremium}, + {dst: &patched.OutputCostPerImageAbove1024x1024Pixels, src: override.OutputCostPerImageAbove1024x1024Pixels}, + {dst: &patched.OutputCostPerImageAbove1024x1024PixelsPremium, src: override.OutputCostPerImageAbove1024x1024PixelsPremium}, + {dst: &patched.OutputCostPerImageAbove2048x2048Pixels, src: override.OutputCostPerImageAbove2048x2048Pixels}, + {dst: &patched.OutputCostPerImageAbove4096x4096Pixels, src: override.OutputCostPerImageAbove4096x4096Pixels}, + {dst: &patched.CacheReadInputImageTokenCost, src: override.CacheReadInputImageTokenCost}, + {dst: &patched.SearchContextCostPerQuery, src: override.SearchContextCostPerQuery}, + {dst: &patched.CodeInterpreterCostPerSession, src: override.CodeInterpreterCostPerSession}, + {dst: &patched.OutputCostPerImageLowQuality, src: override.OutputCostPerImageLowQuality}, + {dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality}, + {dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality}, + {dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality}, + } { + if field.src != nil { + *field.dst = field.src + } } - if override.OutputCostPerImageHighQuality != nil { - patched.OutputCostPerImageHighQuality = override.OutputCostPerImageHighQuality + return patched +} + +func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error { + if mc.configStore == nil { + return nil } - if override.OutputCostPerImageAutoQuality != nil { - patched.OutputCostPerImageAutoQuality = override.OutputCostPerImageAutoQuality + rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilters{}) + if err != nil { + return err } - - return patched + return mc.SetPricingOverrides(rows) } diff --git a/framework/modelcatalog/overrides_test.go b/framework/modelcatalog/overrides_test.go index 5f2ae1df49..8593aad89a 100644 --- a/framework/modelcatalog/overrides_test.go +++ b/framework/modelcatalog/overrides_test.go @@ -3,6 +3,7 @@ package modelcatalog import ( "testing" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/stretchr/testify/assert" @@ -22,150 +23,182 @@ func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuild return schemas.NoopLogEvent } -func TestSetProviderPricingOverrides_InvalidRegex(t *testing.T) { - t.Skip() - mc := newTestCatalog(nil, nil) - err := mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ - { - ModelPattern: "[", - MatchType: schemas.PricingOverrideMatchRegex, - }, - }) - require.Error(t, err) -} - -func TestGetPricing_OverridePrecedenceExactWildcardRegex(t *testing.T) { - t.Skip() +func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) { mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} mc.pricingData[makeKey("gpt-4o", "openai", "chat")] = configstoreTables.TableModelPricing{ Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - exact := 20.0 - wildcard := 10.0 - regex := 30.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ - { - ModelPattern: "gpt-*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &wildcard, - }, + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "^gpt-.*$", - MatchType: schemas.PricingOverrideMatchRegex, - InputCostPerToken: ®ex, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-*", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":10}`, }, { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: &exact, + ID: "openai-override-1", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":20}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) - assert.Equal(t, 20.0, pricing.InputCostPerToken) - assert.Equal(t, 2.0, pricing.OutputCostPerToken) + require.NotNil(t, pricing.InputCostPerToken) + assert.Equal(t, 20.0, *pricing.InputCostPerToken) } -func TestGetPricing_WildcardBeatsRegex(t *testing.T) { +func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) { t.Skip() mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} - mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{ - Model: "gpt-4o-mini", + mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{ + Model: "gpt-4o", Provider: "openai", - Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + Mode: "responses", + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - wildcard := 11.0 - regex := 12.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "^gpt-4o.*$", - MatchType: schemas.PricingOverrideMatchRegex, - InputCostPerToken: ®ex, + ID: "openai-generic", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, { - ModelPattern: "gpt-4o*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &wildcard, + ID: "openai-specific", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + RequestTypes: []schemas.RequestType{schemas.ResponsesRequest}, + PricingPatchJSON: `{"input_cost_per_token":15}`, }, })) - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) - assert.Equal(t, 11.0, pricing.InputCostPerToken) + assert.Equal(t, 15.0, pricing.InputCostPerToken) } -func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) { +func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) { t.Skip() mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} - mc.pricingData[makeKey("gpt-4o", "openai", "responses")] = configstoreTables.TableModelPricing{ + mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{ Model: "gpt-4o", - Provider: "openai", - Mode: "responses", - InputCostPerToken: 1, - OutputCostPerToken: 2, + Provider: "vertex", + Mode: "chat", + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - specific := 15.0 - generic := 9.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + geminiProviderID := "gemini" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: &generic, + ID: "gemini-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &geminiProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":7}`, }, + })) + + pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) + require.NotNil(t, pricing) + assert.Equal(t, 7.0, pricing.InputCostPerToken) +} + +func TestGetPricing_DeploymentLookupUsesResolvedModelForOverrideMatching(t *testing.T) { + mc := newTestCatalog(nil, nil) + mc.logger = noOpLogger{} + mc.pricingData[makeKey("dep-gpt4o", "openai", "chat")] = configstoreTables.TableModelPricing{ + Model: "dep-gpt4o", + Provider: "openai", + Mode: "chat", + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), + } + + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - RequestTypes: []schemas.RequestType{schemas.ResponsesRequest}, - InputCostPerToken: &specific, + ID: "resolved-model-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "dep-gpt4o", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":7}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesRequest) - require.True(t, ok) + // Override pattern matches the resolved model name ("dep-gpt4o"), not the + // originally requested name ("gpt-4o"), because resolved model has priority. + pricing := mc.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) - assert.Equal(t, 15.0, pricing.InputCostPerToken) + require.NotNil(t, pricing.InputCostPerToken) + assert.Equal(t, 7.0, *pricing.InputCostPerToken) } -func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) { - t.Skip() +func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) { mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} mc.pricingData[makeKey("gpt-4o", "vertex", "chat")] = configstoreTables.TableModelPricing{ Model: "gpt-4o", Provider: "vertex", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - override := 7.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.Gemini, []schemas.ProviderPricingOverride{ + geminiProviderID := "gemini" + vertexProviderID := "vertex" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ + { + ID: "gemini-provider-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &geminiProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":5}`, + }, { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: &override, + ID: "vertex-provider-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &vertexProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "gemini", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) require.NotNil(t, pricing) - assert.Equal(t, 7.0, pricing.InputCostPerToken) + require.NotNil(t, pricing.InputCostPerToken) + assert.Equal(t, 5.0, *pricing.InputCostPerToken) } func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) { @@ -176,21 +209,23 @@ func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) Model: "openai/gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - override := 19.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":19}`, }, })) - pricing, ok := mc.getPricing("openai/gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) } @@ -204,22 +239,24 @@ func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), CacheReadInputTokenCost: &baseCacheRead, } - override := 9.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "claude-*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "claude-*", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) assert.Equal(t, 2.0, pricing.OutputCostPerToken) @@ -235,28 +272,29 @@ func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - override := 11.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":11}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 11.0, pricing.InputCostPerToken) - mc.DeleteProviderPricingOverrides(schemas.OpenAI) + require.NoError(t, mc.SetPricingOverrides(nil)) - pricing, ok = mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) } @@ -269,62 +307,74 @@ func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) { Model: "gpt-4o-mini", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - generic := 5.0 - specific := 6.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &generic, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-*", + PricingPatchJSON: `{"input_cost_per_token":5}`, }, { - ModelPattern: "gpt-4o*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &specific, + ID: "openai-override-1", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + PricingPatchJSON: `{"input_cost_per_token":6}`, }, })) - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 6.0, pricing.InputCostPerToken) } -func TestGetPricing_ConfigOrderTiebreakFirstWinsWhenEqual(t *testing.T) { - t.Skip() +// TestGetPricing_FirstInsertionWinsOnTie verifies that when multiple wildcard overrides +// match the same model and scope, the first one inserted takes precedence. +func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) { mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{ Model: "gpt-4o-mini", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), } - first := 8.0 - second := 9.0 - require.NoError(t, mc.SetProviderPricingOverrides(schemas.OpenAI, []schemas.ProviderPricingOverride{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &first, + ID: "a-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":8}`, }, { - ModelPattern: "gpt-4o*", - MatchType: schemas.PricingOverrideMatchWildcard, - InputCostPerToken: &second, + ID: "b-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) - assert.Equal(t, 8.0, pricing.InputCostPerToken) + require.NotNil(t, pricing.InputCostPerToken) + assert.Equal(t, 8.0, *pricing.InputCostPerToken) } func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) { @@ -335,26 +385,123 @@ func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), CacheReadInputTokenCost: &baseCacheRead, InputCostPerImage: &baseInputImage, } - patched := patchPricing(base, schemas.ProviderPricingOverride{ - ModelPattern: "gpt-4o", - MatchType: schemas.PricingOverrideMatchExact, - InputCostPerToken: schemas.Ptr(3.0), - CacheReadInputTokenCost: schemas.Ptr(0.9), + cacheRead := 0.9 + patched := patchPricing(base, PricingOptions{ + InputCostPerToken: bifrost.Ptr(3.0), + CacheReadInputTokenCost: &cacheRead, }) - // Changed fields assert.Equal(t, 3.0, patched.InputCostPerToken) require.NotNil(t, patched.CacheReadInputTokenCost) assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost) - // Unchanged fields assert.Equal(t, 2.0, patched.OutputCostPerToken) require.NotNil(t, patched.InputCostPerImage) assert.Equal(t, 0.7, *patched.InputCostPerImage) } + +func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) { + mc := newTestCatalog(nil, nil) + mc.logger = noOpLogger{} + + providerScopeID := "openai" + providerKeyScopeID := "provider-key-1" + virtualKeyScopeID := "virtual-key-1" + + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ + { + ID: "global", + ScopeKind: string(ScopeKindGlobal), + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":2}`, + }, + { + ID: "provider", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":3}`, + }, + { + ID: "provider-key", + ScopeKind: string(ScopeKindProviderKey), + ProviderKeyID: &providerKeyScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":4}`, + }, + { + ID: "virtual-key", + ScopeKind: string(ScopeKindVirtualKey), + VirtualKeyID: &virtualKeyScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + PricingPatchJSON: `{"input_cost_per_token":5}`, + }, + })) + + base := configstoreTables.TableModelPricing{ + Model: "gpt-5-nano", + Provider: "openai", + Mode: "chat", + InputCostPerToken: bifrost.Ptr(1.0), + OutputCostPerToken: bifrost.Ptr(2.0), + } + + tests := []struct { + name string + scopes PricingLookupScopes + expected float64 + }{ + { + name: "virtual key wins over provider key, provider and global", + scopes: PricingLookupScopes{ + VirtualKeyID: virtualKeyScopeID, + SelectedKeyID: providerKeyScopeID, + Provider: providerScopeID, + }, + expected: 5.0, + }, + { + name: "provider key wins over provider and global", + scopes: PricingLookupScopes{ + SelectedKeyID: providerKeyScopeID, + Provider: providerScopeID, + }, + expected: 4.0, + }, + { + name: "provider wins over global", + scopes: PricingLookupScopes{ + Provider: providerScopeID, + }, + expected: 3.0, + }, + { + name: "global applies when no narrower scope is provided", + scopes: PricingLookupScopes{}, + expected: 2.0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + patched, applied := mc.applyPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes) + require.True(t, applied) + require.NotNil(t, patched.InputCostPerToken) + assert.Equal(t, tc.expected, *patched.InputCostPerToken) + }) + } +} diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index decb3e78ea..6942a3ad05 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -23,22 +23,29 @@ type costInput struct { // CalculateCost calculates the cost of a Bifrost response. // It handles all request types, cache debug billing, and tiered pricing. -func (mc *ModelCatalog) CalculateCost(result *schemas.BifrostResponse) float64 { +// If scopes is nil, an empty PricingLookupScopes is used; global and provider-scoped +// overrides may still apply since the provider is derived from the response. +func (mc *ModelCatalog) CalculateCost(result *schemas.BifrostResponse, scopes *PricingLookupScopes) float64 { if result == nil { return 0 } + var s PricingLookupScopes + if scopes != nil { + s = *scopes + } + // Handle semantic cache billing cacheDebug := result.GetExtraFields().CacheDebug if cacheDebug != nil { - return mc.calculateCostWithCache(result, cacheDebug) + return mc.calculateCostWithCache(result, cacheDebug, s) } - return mc.calculateBaseCost(result) + return mc.calculateBaseCost(result, s) } // calculateCostWithCache handles cost calculation when semantic cache debug info is present. -func (mc *ModelCatalog) calculateCostWithCache(result *schemas.BifrostResponse, cacheDebug *schemas.BifrostCacheDebug) float64 { +func (mc *ModelCatalog) calculateCostWithCache(result *schemas.BifrostResponse, cacheDebug *schemas.BifrostCacheDebug, scopes PricingLookupScopes) float64 { if cacheDebug.CacheHit { // Direct cache hit β€” no LLM call, no cost if cacheDebug.HitType != nil && *cacheDebug.HitType == "direct" { @@ -46,39 +53,42 @@ func (mc *ModelCatalog) calculateCostWithCache(result *schemas.BifrostResponse, } // Semantic cache hit β€” only the embedding lookup cost if cacheDebug.ProviderUsed != nil && cacheDebug.ModelUsed != nil && cacheDebug.InputTokens != nil { - return mc.computeCacheEmbeddingCost(cacheDebug) + return mc.computeCacheEmbeddingCost(cacheDebug, scopes) } return 0 } // Cache miss β€” full LLM cost + embedding lookup cost - baseCost := mc.calculateBaseCost(result) - embeddingCost := mc.computeCacheEmbeddingCost(cacheDebug) + baseCost := mc.calculateBaseCost(result, scopes) + embeddingCost := mc.computeCacheEmbeddingCost(cacheDebug, scopes) return baseCost + embeddingCost } // computeCacheEmbeddingCost calculates the embedding cost for a semantic cache lookup. -func (mc *ModelCatalog) computeCacheEmbeddingCost(cacheDebug *schemas.BifrostCacheDebug) float64 { +func (mc *ModelCatalog) computeCacheEmbeddingCost(cacheDebug *schemas.BifrostCacheDebug, scopes PricingLookupScopes) float64 { if cacheDebug == nil || cacheDebug.ProviderUsed == nil || cacheDebug.ModelUsed == nil || cacheDebug.InputTokens == nil { return 0 } - pricing, exists := mc.getPricing(*cacheDebug.ModelUsed, *cacheDebug.ProviderUsed, schemas.EmbeddingRequest) - if !exists { + if scopes.Provider == "" { + scopes.Provider = *cacheDebug.ProviderUsed + } + pricing := mc.resolvePricing(*cacheDebug.ProviderUsed, *cacheDebug.ModelUsed, "", schemas.EmbeddingRequest, scopes) + if pricing == nil { return 0 } return float64(*cacheDebug.InputTokens) * tieredInputRate(pricing, *cacheDebug.InputTokens) } // calculateBaseCost extracts usage from the response and routes to the appropriate compute function. -func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse) float64 { +func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scopes PricingLookupScopes) float64 { extraFields := result.GetExtraFields() if extraFields == nil { return 0 } provider := string(extraFields.Provider) - model := extraFields.ModelRequested - deployment := extraFields.ModelDeployment + originalModelRequested := extraFields.OriginalModelRequested + resolvedModelUsed := extraFields.ResolvedModelUsed requestType := extraFields.RequestType // Extract usage data from the response @@ -98,14 +108,14 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse) float requestType = normalizeStreamRequestType(requestType) // Resolve pricing entry with deployment fallback - pricing := mc.resolvePricing(provider, model, deployment, requestType) + pricing := mc.resolvePricing(provider, originalModelRequested, resolvedModelUsed, requestType, scopes) if pricing == nil { return 0 } // Route to the appropriate compute function switch requestType { - case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest: + case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest, schemas.RealtimeRequest: return computeTextCost(pricing, input.usage) case schemas.EmbeddingRequest: return computeEmbeddingCost(pricing, input.usage) @@ -598,7 +608,10 @@ func tieredInputRate(pricing *configstoreTables.TableModelPricing, totalTokens i if totalTokens > TokenTierAbove128K && pricing.InputCostPerTokenAbove128kTokens != nil { return *pricing.InputCostPerTokenAbove128kTokens } - return pricing.InputCostPerToken + if pricing.InputCostPerToken != nil { + return *pricing.InputCostPerToken + } + return 0 } // tieredOutputRate returns the effective per-token output rate based on total token count. @@ -609,7 +622,10 @@ func tieredOutputRate(pricing *configstoreTables.TableModelPricing, totalTokens if totalTokens > TokenTierAbove128K && pricing.OutputCostPerTokenAbove128kTokens != nil { return *pricing.OutputCostPerTokenAbove128kTokens } - return pricing.OutputCostPerToken + if pricing.OutputCostPerToken != nil { + return *pricing.OutputCostPerToken + } + return 0 } // tieredImageInputRate returns the effective rate for image tokens on the input side. @@ -743,28 +759,61 @@ func populateOutputImageCount(imageUsage *schemas.ImageUsage, dataLen int) { // --------------------------------------------------------------------------- // resolvePricing resolves the pricing entry for a model, trying deployment as fallback. -func (mc *ModelCatalog) resolvePricing(provider, model, deployment string, requestType schemas.RequestType) *configstoreTables.TableModelPricing { - mc.logger.Debug("looking up pricing for model %s and provider %s of request type %s", model, provider, normalizeRequestType(requestType)) +func (mc *ModelCatalog) resolvePricing(provider, originalModelRequested, resolvedModelUsed string, requestType schemas.RequestType, scopes PricingLookupScopes) *configstoreTables.TableModelPricing { + if resolvedModelUsed == "" { + resolvedModelUsed = originalModelRequested + } + mc.logger.Debug("looking up pricing for resolved model %s and provider %s of request type %s", resolvedModelUsed, provider, normalizeRequestType(requestType)) - pricing, exists := mc.getPricing(model, provider, requestType) - if exists { - return pricing + if scopes.Provider == "" { + scopes.Provider = provider } - if deployment != "" { - mc.logger.Debug("pricing not found for model %s, trying deployment %s", model, deployment) - pricing, exists = mc.getPricing(deployment, provider, requestType) - if exists { - return pricing - } + base, exists := mc.getBasePricing(resolvedModelUsed, provider, requestType) + if exists && base != nil { + result, _ := mc.applyPricingOverrides(resolvedModelUsed, requestType, *base, scopes) + return &result } - mc.logger.Debug("pricing not found for model %s and provider %s, skipping cost calculation", model, provider) + mc.logger.Debug("pricing not found for resolved model %s, trying alias %s", resolvedModelUsed, originalModelRequested) + base, exists = mc.getBasePricing(originalModelRequested, provider, requestType) + if exists && base != nil { + // Apply overrides using the resolved model name, not the alias + result, _ := mc.applyPricingOverrides(resolvedModelUsed, requestType, *base, scopes) + return &result + } + + // No base catalog entry found; still try overrides in case the user defined + // override-only pricing for a model not in the built-in catalog. + mc.logger.Debug("pricing not found for resolved model %s and provider %s, trying override-only pricing", resolvedModelUsed, provider) + result, applied := mc.applyPricingOverrides(resolvedModelUsed, requestType, configstoreTables.TableModelPricing{}, scopes) + if applied { + return &result + } + mc.logger.Debug("no pricing found for resolved model %s and provider %s, skipping cost calculation", resolvedModelUsed, provider) return nil } -// getPricing returns pricing information for a model (thread-safe) -func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.RequestType) (*configstoreTables.TableModelPricing, bool) { +// getBasePricing looks up catalog pricing for the given model, provider, and request type. +// It applies a provider-specific fallback chain when an exact match is not found: +// +// - Gemini: retries under the "vertex" provider, then falls back to chat mode for Responses requests. +// - Vertex: strips the "provider/model" prefix and retries, then falls back to chat mode for Responses requests. +// - Bedrock: prepends the "anthropic." namespace for Claude models, then falls back to chat mode for Responses requests. +// - All providers: for Responses/ResponsesStream requests, retries the lookup in chat mode. +// - All providers: for ImageEdit/ImageVariation requests, retries the lookup in image-generation mode. +// +// The method acquires a read lock for the duration of the lookup. +// +// Input: model β€” exact model name to look up. +// +// provider β€” provider identifier (e.g. "openai", "anthropic"). +// requestType β€” the request type used to derive the pricing mode. +// +// Output: TableModelPricing β€” the matched pricing row (zero value when not found). +// +// bool β€” true when a pricing entry was found, false otherwise. +func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schemas.RequestType) (*configstoreTables.TableModelPricing, bool) { mc.mu.RLock() defer mc.mu.RUnlock() @@ -784,7 +833,7 @@ func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.R } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying vertex provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(model, "vertex", normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -804,7 +853,7 @@ func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.R } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying vertex provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(modelWithoutProvider, "vertex", normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -824,7 +873,7 @@ func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.R } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying chat provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey("anthropic."+model, provider, normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -835,7 +884,7 @@ func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.R } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("primary lookup failed, trying chat provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(model, provider, normalizeRequestType(schemas.ChatCompletionRequest))] if ok { diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index 1433e0035f..8a4c176092 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -3,6 +3,7 @@ package modelcatalog import ( "testing" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/stretchr/testify/assert" @@ -13,17 +14,14 @@ import ( // helpers // --------------------------------------------------------------------------- -func ptr(v float64) *float64 { return &v } -func intPtr(v int) *int { return &v } - // chatPricing returns a TableModelPricing with the given per-token rates. func chatPricing(input, output float64) configstoreTables.TableModelPricing { return configstoreTables.TableModelPricing{ Model: "test-model", Provider: "test-provider", Mode: "chat", - InputCostPerToken: input, - OutputCostPerToken: output, + InputCostPerToken: bifrost.Ptr(input), + OutputCostPerToken: bifrost.Ptr(output), } } @@ -43,9 +41,9 @@ func makeChatResponse(provider schemas.ModelProvider, model string, usage *schem ChatResponse: &schemas.BifrostChatResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.ChatCompletionRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -57,9 +55,9 @@ func makeEmbeddingResponse(provider schemas.ModelProvider, model string, usage * EmbeddingResponse: &schemas.BifrostEmbeddingResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.EmbeddingRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -71,9 +69,9 @@ func makeRerankResponse(provider schemas.ModelProvider, model string, usage *sch RerankResponse: &schemas.BifrostRerankResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.RerankRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.RerankRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } @@ -85,14 +83,21 @@ func makeImageResponse(provider schemas.ModelProvider, model string, usage *sche ImageGenerationResponse: &schemas.BifrostImageGenerationResponse{ Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: provider, - ModelRequested: model, + RequestType: schemas.ImageGenerationRequest, + Provider: provider, + OriginalModelRequested: model, }, }, } } +func derefF(f *float64) float64 { + if f == nil { + return 0 + } + return *f +} + // ========================================================================= // 1. computeTextCost β€” unit tests (pure function, no catalog) // ========================================================================= @@ -124,8 +129,8 @@ func TestComputeTextCost_ZeroTokens(t *testing.T) { func TestComputeTextCost_WithCachedPromptTokens(t *testing.T) { // Claude 3.5 Sonnet (Bedrock): input=$3/M, output=$15/M, cache_read=$0.3/M, cache_creation=$3.75/M p := chatPricing(0.000003, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000003) - p.CacheCreationInputTokenCost = ptr(0.00000375) + p.CacheReadInputTokenCost = bifrost.Ptr(0.0000003) + p.CacheCreationInputTokenCost = bifrost.Ptr(0.00000375) usage := &schemas.BifrostLLMUsage{ PromptTokens: 2000, @@ -149,8 +154,8 @@ func TestComputeTextCost_WithCachedPromptTokens(t *testing.T) { func TestComputeTextCost_Tiered200k(t *testing.T) { // Claude 3.5 Sonnet Bedrock 200k tier: input=$6/M, output=$30/M p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) + p.InputCostPerTokenAbove200kTokens = bifrost.Ptr(0.000006) + p.OutputCostPerTokenAbove200kTokens = bifrost.Ptr(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 180000, @@ -167,8 +172,8 @@ func TestComputeTextCost_Tiered200k(t *testing.T) { func TestComputeTextCost_Below200kUsesBaseRate(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.InputCostPerTokenAbove200kTokens = ptr(0.000006) - p.OutputCostPerTokenAbove200kTokens = ptr(0.00003) + p.InputCostPerTokenAbove200kTokens = bifrost.Ptr(0.000006) + p.OutputCostPerTokenAbove200kTokens = bifrost.Ptr(0.00003) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -185,7 +190,7 @@ func TestComputeTextCost_Below200kUsesBaseRate(t *testing.T) { func TestComputeTextCost_SearchQueryCost(t *testing.T) { p := chatPricing(0.000003, 0.000015) - p.SearchContextCostPerQuery = ptr(0.01) // $0.01 per search query + p.SearchContextCostPerQuery = bifrost.Ptr(0.01) // $0.01 per search query numQueries := 3 usage := &schemas.BifrostLLMUsage{ @@ -232,8 +237,8 @@ func TestComputeTextCost_NoCacheRateFallsBackToBaseInputRate(t *testing.T) { func TestComputeEmbeddingCost_Basic(t *testing.T) { // Titan Embed Text v1: $0.1/M input p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.0000001, - OutputCostPerToken: 0, + InputCostPerToken: bifrost.Ptr(0.0000001), + OutputCostPerToken: bifrost.Ptr(0.0), } usage := &schemas.BifrostLLMUsage{ PromptTokens: 5000, @@ -245,7 +250,7 @@ func TestComputeEmbeddingCost_Basic(t *testing.T) { } func TestComputeEmbeddingCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{InputCostPerToken: 0.0000001} + p := configstoreTables.TableModelPricing{InputCostPerToken: bifrost.Ptr(0.0000001)} assert.Equal(t, 0.0, computeEmbeddingCost(&p, nil)) } @@ -255,8 +260,8 @@ func TestComputeEmbeddingCost_NilUsage(t *testing.T) { func TestComputeRerankCost_Basic(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000001, - OutputCostPerToken: 0.000002, + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerToken: bifrost.Ptr(0.000002), } usage := &schemas.BifrostLLMUsage{ PromptTokens: 2000, @@ -270,9 +275,9 @@ func TestComputeRerankCost_Basic(t *testing.T) { func TestComputeRerankCost_WithSearchCost(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0, - OutputCostPerToken: 0, - SearchContextCostPerQuery: ptr(0.001), + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerToken: bifrost.Ptr(0.0), + SearchContextCostPerQuery: bifrost.Ptr(0.001), } numQueries := 5 usage := &schemas.BifrostLLMUsage{ @@ -285,7 +290,7 @@ func TestComputeRerankCost_WithSearchCost(t *testing.T) { } func TestComputeRerankCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{InputCostPerToken: 0.001} + p := configstoreTables.TableModelPricing{InputCostPerToken: bifrost.Ptr(0.001)} assert.Equal(t, 0.0, computeRerankCost(&p, nil)) } @@ -296,9 +301,9 @@ func TestComputeRerankCost_NilUsage(t *testing.T) { func TestComputeSpeechCost_TokensPreferredOverDuration(t *testing.T) { // TTS: input=text tokens, output=audio tokens (preferred over per-second) p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.0000025, - OutputCostPerToken: 0.00001, - OutputCostPerSecond: ptr(0.00025), + InputCostPerToken: bifrost.Ptr(0.0000025), + OutputCostPerToken: bifrost.Ptr(0.00001), + OutputCostPerSecond: bifrost.Ptr(0.00025), } seconds := 60 usage := &schemas.BifrostLLMUsage{ @@ -317,9 +322,9 @@ func TestComputeSpeechCost_TokensPreferredOverDuration(t *testing.T) { func TestComputeSpeechCost_OutputFallsBackToPerSecond(t *testing.T) { // TTS: no output tokens β†’ falls back to per-second output pricing p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000001, - OutputCostPerToken: 0.000002, - OutputCostPerSecond: ptr(0.0001), + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerToken: bifrost.Ptr(0.000002), + OutputCostPerSecond: bifrost.Ptr(0.0001), } seconds := 120 usage := &schemas.BifrostLLMUsage{PromptTokens: 500} @@ -333,9 +338,9 @@ func TestComputeSpeechCost_OutputFallsBackToPerSecond(t *testing.T) { func TestComputeSpeechCost_OutputAudioTokenRate(t *testing.T) { // TTS: output uses OutputCostPerAudioToken when available p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000001, - OutputCostPerToken: 0.000002, - OutputCostPerAudioToken: ptr(0.00005), + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerToken: bifrost.Ptr(0.000002), + OutputCostPerAudioToken: bifrost.Ptr(0.00005), } usage := &schemas.BifrostLLMUsage{ PromptTokens: 200, @@ -373,9 +378,9 @@ func TestComputeSpeechCost_NilUsageNilSeconds(t *testing.T) { func TestComputeTranscriptionCost_DurationBased(t *testing.T) { // assemblyai/nano: input_cost_per_second=0.00010278 p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0, - OutputCostPerToken: 0, - InputCostPerSecond: ptr(0.00010278), + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerToken: bifrost.Ptr(0.0), + InputCostPerSecond: bifrost.Ptr(0.00010278), } seconds := 300 // 5 minutes cost := computeTranscriptionCost(&p, nil, &seconds, nil) @@ -385,9 +390,9 @@ func TestComputeTranscriptionCost_DurationBased(t *testing.T) { func TestComputeTranscriptionCost_AudioTokenDetails(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerAudioToken: ptr(0.00001), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerAudioToken: bifrost.Ptr(0.00001), } usage := &schemas.BifrostLLMUsage{ PromptTokens: 2000, @@ -421,10 +426,10 @@ func TestComputeTranscriptionCost_TokenFallback(t *testing.T) { func TestComputeTranscriptionCost_TokenDetailsPreferredOverDuration(t *testing.T) { // STT: audio token details present β†’ uses tokens, not per-second p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0, - InputCostPerAudioPerSecond: ptr(0.0001), - InputCostPerAudioToken: ptr(0.00001), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.0), + InputCostPerAudioPerSecond: bifrost.Ptr(0.0001), + InputCostPerAudioToken: bifrost.Ptr(0.00001), } seconds := 60 audioDetails := &schemas.TranscriptionUsageInputTokenDetails{ @@ -443,9 +448,9 @@ func TestComputeTranscriptionCost_TokenDetailsPreferredOverDuration(t *testing.T func TestComputeTranscriptionCost_DurationFallbackWhenNoTokens(t *testing.T) { // STT: no audio token details, no prompt tokens β†’ falls back to per-second p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerAudioPerSecond: ptr(0.0001), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerAudioPerSecond: bifrost.Ptr(0.0001), } seconds := 60 usage := &schemas.BifrostLLMUsage{ @@ -466,9 +471,9 @@ func TestComputeTranscriptionCost_DurationFallbackWhenNoTokens(t *testing.T) { func TestComputeImageCost_PerImage(t *testing.T) { // dall-e-3 (aiml): output_cost_per_image=$0.052 p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0, - OutputCostPerToken: 0, - OutputCostPerImage: ptr(0.052), + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerImage: bifrost.Ptr(0.052), } usage := &schemas.ImageUsage{ OutputTokensDetails: &schemas.ImageTokenDetails{ @@ -482,7 +487,7 @@ func TestComputeImageCost_PerImage(t *testing.T) { func TestComputeImageCost_PerImageDefaultsToOne(t *testing.T) { p := configstoreTables.TableModelPricing{ - OutputCostPerImage: ptr(0.052), + OutputCostPerImage: bifrost.Ptr(0.052), } usage := &schemas.ImageUsage{} // No token details β†’ defaults to 1 image cost := computeImageCost(&p, usage, "", "") @@ -491,8 +496,8 @@ func TestComputeImageCost_PerImageDefaultsToOne(t *testing.T) { func TestComputeImageCost_TokenBased(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), } usage := &schemas.ImageUsage{ InputTokens: 1000, @@ -506,8 +511,8 @@ func TestComputeImageCost_TokenBased(t *testing.T) { func TestComputeImageCost_TokenBasedWithDetails(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), } usage := &schemas.ImageUsage{ InputTokens: 2000, @@ -530,14 +535,14 @@ func TestComputeImageCost_TokenBasedWithDetails(t *testing.T) { } func TestComputeImageCost_NilUsage(t *testing.T) { - p := configstoreTables.TableModelPricing{OutputCostPerImage: ptr(0.05)} + p := configstoreTables.TableModelPricing{OutputCostPerImage: bifrost.Ptr(0.05)} assert.Equal(t, 0.0, computeImageCost(&p, nil, "", "")) } func TestComputeImageCost_InputAndOutputPerImage(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerImage: ptr(0.01), - OutputCostPerImage: ptr(0.05), + InputCostPerImage: bifrost.Ptr(0.01), + OutputCostPerImage: bifrost.Ptr(0.05), } usage := &schemas.ImageUsage{ NumInputImages: 3, @@ -550,7 +555,7 @@ func TestComputeImageCost_InputAndOutputPerImage(t *testing.T) { func TestComputeImageCost_PerPixelOutput(t *testing.T) { p := configstoreTables.TableModelPricing{ - OutputCostPerPixel: ptr(0.000000019), // ~$0.02 for 1024x1024 + OutputCostPerPixel: bifrost.Ptr(0.000000019), // ~$0.02 for 1024x1024 } usage := &schemas.ImageUsage{ OutputTokensDetails: &schemas.ImageTokenDetails{NImages: 1}, @@ -562,8 +567,8 @@ func TestComputeImageCost_PerPixelOutput(t *testing.T) { func TestComputeImageCost_PerPixelInputAndOutput(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerPixel: ptr(0.00000001), - OutputCostPerPixel: ptr(0.00000002), + InputCostPerPixel: bifrost.Ptr(0.00000001), + OutputCostPerPixel: bifrost.Ptr(0.00000002), } usage := &schemas.ImageUsage{ NumInputImages: 2, @@ -579,10 +584,10 @@ func TestComputeImageCost_PerPixelInputAndOutput(t *testing.T) { func TestComputeImageCost_TokensPreferredOverPixels(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerPixel: ptr(0.00000001), - OutputCostPerPixel: ptr(0.00000002), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerPixel: bifrost.Ptr(0.00000001), + OutputCostPerPixel: bifrost.Ptr(0.00000002), } usage := &schemas.ImageUsage{ InputTokens: 1000, @@ -596,8 +601,8 @@ func TestComputeImageCost_TokensPreferredOverPixels(t *testing.T) { func TestComputeImageCost_PixelsPreferredOverPerImage(t *testing.T) { p := configstoreTables.TableModelPricing{ - OutputCostPerPixel: ptr(0.00000002), - OutputCostPerImage: ptr(999.0), // should not be used + OutputCostPerPixel: bifrost.Ptr(0.00000002), + OutputCostPerImage: bifrost.Ptr(999.0), // should not be used } usage := &schemas.ImageUsage{ OutputTokensDetails: &schemas.ImageTokenDetails{NImages: 1}, @@ -609,8 +614,8 @@ func TestComputeImageCost_PixelsPreferredOverPerImage(t *testing.T) { func TestComputeImageCost_PerPixelFallsBackToPerImage_WhenNoSize(t *testing.T) { p := configstoreTables.TableModelPricing{ - OutputCostPerPixel: ptr(0.00000002), - OutputCostPerImage: ptr(0.05), + OutputCostPerPixel: bifrost.Ptr(0.00000002), + OutputCostPerImage: bifrost.Ptr(0.05), } usage := &schemas.ImageUsage{ OutputTokensDetails: &schemas.ImageTokenDetails{NImages: 2}, @@ -626,11 +631,11 @@ func TestComputeImageCost_QualityBasedRates(t *testing.T) { } // Quality-specific rates take precedence over base/size-tier p := configstoreTables.TableModelPricing{ - OutputCostPerImage: ptr(0.01), - OutputCostPerImageLowQuality: ptr(0.02), - OutputCostPerImageMediumQuality: ptr(0.03), - OutputCostPerImageHighQuality: ptr(0.04), - OutputCostPerImageAutoQuality: ptr(0.05), + OutputCostPerImage: bifrost.Ptr(0.01), + OutputCostPerImageLowQuality: bifrost.Ptr(0.02), + OutputCostPerImageMediumQuality: bifrost.Ptr(0.03), + OutputCostPerImageHighQuality: bifrost.Ptr(0.04), + OutputCostPerImageAutoQuality: bifrost.Ptr(0.05), } assert.InDelta(t, 0.02, computeImageCost(&p, usage, "", "low"), 1e-12) assert.InDelta(t, 0.03, computeImageCost(&p, usage, "", "medium"), 1e-12) @@ -659,9 +664,9 @@ func TestParseImagePixels(t *testing.T) { func TestComputeVideoCost_DurationBased(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000001, - OutputCostPerToken: 0, - OutputCostPerVideoPerSecond: ptr(0.001), + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerVideoPerSecond: bifrost.Ptr(0.001), } seconds := 30 usage := &schemas.BifrostLLMUsage{PromptTokens: 500, TotalTokens: 500} @@ -674,9 +679,9 @@ func TestComputeVideoCost_DurationBased(t *testing.T) { func TestComputeVideoCost_OutputCostPerSecondFallback(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0, - OutputCostPerToken: 0, - OutputCostPerSecond: ptr(0.002), + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerSecond: bifrost.Ptr(0.002), } seconds := 10 cost := computeVideoCost(&p, nil, &seconds) @@ -685,8 +690,8 @@ func TestComputeVideoCost_OutputCostPerSecondFallback(t *testing.T) { func TestComputeVideoCost_NilSeconds(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000001, - OutputCostPerVideoPerSecond: ptr(0.001), + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerVideoPerSecond: bifrost.Ptr(0.001), } usage := &schemas.BifrostLLMUsage{PromptTokens: 1000} cost := computeVideoCost(&p, usage, nil) @@ -700,23 +705,23 @@ func TestComputeVideoCost_NilSeconds(t *testing.T) { func TestTieredInputRate_BelowThreshold(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000003, - InputCostPerTokenAbove200kTokens: ptr(0.000006), + InputCostPerToken: bifrost.Ptr(0.000003), + InputCostPerTokenAbove200kTokens: bifrost.Ptr(0.000006), } assert.Equal(t, 0.000003, tieredInputRate(&p, 100000)) } func TestTieredInputRate_AboveThreshold(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000003, - InputCostPerTokenAbove200kTokens: ptr(0.000006), + InputCostPerToken: bifrost.Ptr(0.000003), + InputCostPerTokenAbove200kTokens: bifrost.Ptr(0.000006), } assert.Equal(t, 0.000006, tieredInputRate(&p, 210000)) } func TestTieredInputRate_AboveThresholdNoTieredRate(t *testing.T) { p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000003, + InputCostPerToken: bifrost.Ptr(0.000003), } // Falls back to base rate when tiered field is nil assert.Equal(t, 0.000003, tieredInputRate(&p, 300000)) @@ -724,8 +729,8 @@ func TestTieredInputRate_AboveThresholdNoTieredRate(t *testing.T) { func TestTieredOutputRate_AboveThreshold(t *testing.T) { p := configstoreTables.TableModelPricing{ - OutputCostPerToken: 0.000015, - OutputCostPerTokenAbove200kTokens: ptr(0.00003), + OutputCostPerToken: bifrost.Ptr(0.000015), + OutputCostPerTokenAbove200kTokens: bifrost.Ptr(0.00003), } assert.Equal(t, 0.00003, tieredOutputRate(&p, 250000)) } @@ -772,9 +777,9 @@ func TestExtractCostInput_TranscriptionWithSeconds(t *testing.T) { TranscriptionResponse: &schemas.BifrostTranscriptionResponse{ Usage: &schemas.TranscriptionUsage{ Seconds: &sec, - InputTokens: intPtr(1000), - OutputTokens: intPtr(200), - TotalTokens: intPtr(1200), + InputTokens: bifrost.Ptr(1000), + OutputTokens: bifrost.Ptr(200), + TotalTokens: bifrost.Ptr(1200), }, }, } @@ -833,7 +838,7 @@ func TestCalculateCost_SemanticCacheDirectHit(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, OutputCostPerToken: 0.000015, + InputCostPerToken: bifrost.Ptr(0.000005), OutputCostPerToken: bifrost.Ptr(0.000015), }, }) @@ -842,9 +847,9 @@ func TestCalculateCost_SemanticCacheDirectHit(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: true, HitType: &hitType, @@ -853,7 +858,7 @@ func TestCalculateCost_SemanticCacheDirectHit(t *testing.T) { }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.0, cost) } @@ -865,11 +870,11 @@ func TestCalculateCost_SemanticCacheSemanticHit(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, OutputCostPerToken: 0.000015, + InputCostPerToken: bifrost.Ptr(0.000005), OutputCostPerToken: bifrost.Ptr(0.000015), }, makeKey("text-embedding-3-small", "openai", "embedding"): { Model: "text-embedding-3-small", Provider: "openai", Mode: "embedding", - InputCostPerToken: 0.00000002, + InputCostPerToken: bifrost.Ptr(0.00000002), }, }) @@ -878,9 +883,9 @@ func TestCalculateCost_SemanticCacheSemanticHit(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: true, HitType: &hitType, @@ -892,7 +897,7 @@ func TestCalculateCost_SemanticCacheSemanticHit(t *testing.T) { }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Only embedding cost: 500 * 0.00000002 = 0.00001 assert.InDelta(t, 0.00001, cost, 1e-12) } @@ -905,11 +910,11 @@ func TestCalculateCost_SemanticCacheMiss(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, OutputCostPerToken: 0.000015, + InputCostPerToken: bifrost.Ptr(0.000005), OutputCostPerToken: bifrost.Ptr(0.000015), }, makeKey("text-embedding-3-small", "openai", "embedding"): { Model: "text-embedding-3-small", Provider: "openai", Mode: "embedding", - InputCostPerToken: 0.00000002, + InputCostPerToken: bifrost.Ptr(0.00000002), }, }) @@ -917,9 +922,9 @@ func TestCalculateCost_SemanticCacheMiss(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 1000, CompletionTokens: 500, TotalTokens: 1500}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", CacheDebug: &schemas.BifrostCacheDebug{ CacheHit: false, ProviderUsed: &embProvider, @@ -930,7 +935,7 @@ func TestCalculateCost_SemanticCacheMiss(t *testing.T) { }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Base cost: 1000*0.000005 + 500*0.000015 = 0.005 + 0.0075 = 0.0125 // Embedding cost: 500 * 0.00000002 = 0.00001 // Total: 0.01251 @@ -951,7 +956,7 @@ func TestCalculateCost_SemanticCacheHitNoEmbeddingInfo(t *testing.T) { }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.0, cost) } @@ -961,7 +966,7 @@ func TestCalculateCost_SemanticCacheHitNoEmbeddingInfo(t *testing.T) { func TestCalculateCost_NilResponse(t *testing.T) { mc := testCatalogWithPricing(nil) - assert.Equal(t, 0.0, mc.CalculateCost(nil)) + assert.Equal(t, 0.0, mc.CalculateCost(nil, nil)) } func TestCalculateCost_ProviderComputedCostPassthrough(t *testing.T) { @@ -978,7 +983,7 @@ func TestCalculateCost_ProviderComputedCostPassthrough(t *testing.T) { }, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.99, cost) } @@ -988,7 +993,7 @@ func TestCalculateCost_NoUsageData(t *testing.T) { }) resp := makeChatResponse(schemas.OpenAI, "gpt-4o", nil) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.0, cost) } @@ -997,9 +1002,9 @@ func TestCalculateCost_ChatCompletion_GPT4o(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): { Model: "gpt-4o", Provider: "openai", Mode: "chat", - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - CacheReadInputTokenCost: ptr(0.0000005), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + CacheReadInputTokenCost: bifrost.Ptr(0.0000005), }, }) @@ -1009,7 +1014,7 @@ func TestCalculateCost_ChatCompletion_GPT4o(t *testing.T) { TotalTokens: 12000, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 10000*0.000005 + 2000*0.000015 = 0.05 + 0.03 = 0.08 assert.InDelta(t, 0.08, cost, 1e-12) } @@ -1019,12 +1024,12 @@ func TestCalculateCost_ChatCompletion_Claude35Sonnet_WithCache(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("anthropic.claude-3-5-sonnet-20241022-v2:0", "bedrock", "chat"): { Model: "anthropic.claude-3-5-sonnet-20241022-v2:0", Provider: "bedrock", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - CacheReadInputTokenCost: ptr(0.0000003), - CacheCreationInputTokenCost: ptr(0.00000375), - InputCostPerTokenAbove200kTokens: ptr(0.000006), - OutputCostPerTokenAbove200kTokens: ptr(0.00003), + InputCostPerToken: bifrost.Ptr(0.000003), + OutputCostPerToken: bifrost.Ptr(0.000015), + CacheReadInputTokenCost: bifrost.Ptr(0.0000003), + CacheCreationInputTokenCost: bifrost.Ptr(0.00000375), + InputCostPerTokenAbove200kTokens: bifrost.Ptr(0.000006), + OutputCostPerTokenAbove200kTokens: bifrost.Ptr(0.00003), }, }) @@ -1038,7 +1043,7 @@ func TestCalculateCost_ChatCompletion_Claude35Sonnet_WithCache(t *testing.T) { }, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Both cached read and write tokens are input-side deductions from promptTokens. // Input: (5000-3000-500)*0.000003 + 3000*0.0000003 + 500*0.00000375 = 0.0045 + 0.0009 + 0.001875 = 0.007275 // Output: 1000*0.000015 = 0.015 @@ -1051,8 +1056,8 @@ func TestCalculateCost_Embedding(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("amazon.titan-embed-text-v1", "bedrock", "embedding"): { Model: "amazon.titan-embed-text-v1", Provider: "bedrock", Mode: "embedding", - InputCostPerToken: 0.0000001, - OutputCostPerToken: 0, + InputCostPerToken: bifrost.Ptr(0.0000001), + OutputCostPerToken: bifrost.Ptr(0.0), }, }) @@ -1061,7 +1066,7 @@ func TestCalculateCost_Embedding(t *testing.T) { TotalTokens: 10000, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 10000 * 0.0000001 = 0.001 assert.InDelta(t, 0.001, cost, 1e-12) } @@ -1070,8 +1075,8 @@ func TestCalculateCost_Rerank(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("amazon.rerank-v1:0", "bedrock", "rerank"): { Model: "amazon.rerank-v1:0", Provider: "bedrock", Mode: "rerank", - InputCostPerToken: 0, - OutputCostPerToken: 0, + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerToken: bifrost.Ptr(0.0), }, }) @@ -1080,7 +1085,7 @@ func TestCalculateCost_Rerank(t *testing.T) { TotalTokens: 500, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.0, cost) } @@ -1089,7 +1094,7 @@ func TestCalculateCost_ImageGeneration(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("dall-e-3", "aiml", "image_generation"): { Model: "dall-e-3", Provider: "aiml", Mode: "image_generation", - OutputCostPerImage: ptr(0.052), + OutputCostPerImage: bifrost.Ptr(0.052), }, }) @@ -1097,7 +1102,7 @@ func TestCalculateCost_ImageGeneration(t *testing.T) { OutputTokensDetails: &schemas.ImageTokenDetails{NImages: 3}, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 3 * 0.052 = 0.156 assert.InDelta(t, 0.156, cost, 1e-12) } @@ -1112,14 +1117,14 @@ func TestCalculateCost_StreamRequestTypeNormalized(t *testing.T) { ChatResponse: &schemas.BifrostChatResponse{ Usage: &schemas.BifrostLLMUsage{PromptTokens: 1000, CompletionTokens: 500, TotalTokens: 1500}, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o", + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.InDelta(t, 0.0125, cost, 1e-12) } @@ -1128,7 +1133,7 @@ func TestCalculateCost_NoPricingData(t *testing.T) { resp := makeChatResponse(schemas.OpenAI, "unknown-model", &schemas.BifrostLLMUsage{ PromptTokens: 1000, CompletionTokens: 500, TotalTokens: 1500, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.Equal(t, 0.0, cost) } @@ -1140,57 +1145,59 @@ func TestGetPricing_DirectLookup(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) - assert.Equal(t, 0.000005, p.InputCostPerToken) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) } func TestGetPricing_GeminiFallsBackToVertex(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gemini-2.0-flash", "vertex", "chat"): { Model: "gemini-2.0-flash", Provider: "vertex", Mode: "chat", - InputCostPerToken: 0.0000001, OutputCostPerToken: 0.0000004, + InputCostPerToken: bifrost.Ptr(0.0000001), OutputCostPerToken: bifrost.Ptr(0.0000004), }, }) - p, ok := mc.getPricing("gemini-2.0-flash", "gemini", schemas.ChatCompletionRequest) - require.True(t, ok) - assert.Equal(t, 0.0000001, p.InputCostPerToken) + p := mc.resolvePricing("gemini", "gemini-2.0-flash", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) + assert.Equal(t, 0.0000001, derefF(p.InputCostPerToken)) } func TestGetPricing_VertexStripsProviderPrefix(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gemini-2.0-flash", "vertex", "chat"): chatPricing(0.0000001, 0.0000004), }) - p, ok := mc.getPricing("google/gemini-2.0-flash", "vertex", schemas.ChatCompletionRequest) - require.True(t, ok) - assert.Equal(t, 0.0000001, p.InputCostPerToken) + p := mc.resolvePricing("vertex", "google/gemini-2.0-flash", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "vertex"}) + assert.Equal(t, 0.0000001, derefF(p.InputCostPerToken)) } func TestGetPricing_BedrockAddsAnthropicPrefix(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("anthropic.claude-3-5-sonnet-20241022-v2:0", "bedrock", "chat"): chatPricing(0.000003, 0.000015), }) - p, ok := mc.getPricing("claude-3-5-sonnet-20241022-v2:0", "bedrock", schemas.ChatCompletionRequest) - require.True(t, ok) - assert.Equal(t, 0.000003, p.InputCostPerToken) + p := mc.resolvePricing("bedrock", "claude-3-5-sonnet-20241022-v2:0", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "bedrock"}) + assert.Equal(t, 0.000003, derefF(p.InputCostPerToken)) } func TestGetPricing_ResponsesFallsBackToChat(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesRequest) - require.True(t, ok) - assert.Equal(t, 0.000005, p.InputCostPerToken) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"}) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) } func TestGetPricing_ResponsesStreamFallsBackToChat(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesStreamRequest) - require.True(t, ok) - assert.Equal(t, 0.000005, p.InputCostPerToken) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesStreamRequest, PricingLookupScopes{Provider: "openai"}) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) +} + +func TestGetPricing_RealtimeFallsBackToChat(t *testing.T) { + mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ + makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), + }) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.RealtimeRequest, PricingLookupScopes{Provider: "openai"}) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) } func TestGetPricing_GeminiResponsesFallsBackToVertexChat(t *testing.T) { @@ -1198,15 +1205,14 @@ func TestGetPricing_GeminiResponsesFallsBackToVertexChat(t *testing.T) { makeKey("gemini-2.0-flash", "vertex", "chat"): chatPricing(0.0000001, 0.0000004), }) // gemini provider + responses request β†’ try vertex + responses β†’ try vertex + chat - p, ok := mc.getPricing("gemini-2.0-flash", "gemini", schemas.ResponsesRequest) - require.True(t, ok) - assert.Equal(t, 0.0000001, p.InputCostPerToken) + p := mc.resolvePricing("gemini", "gemini-2.0-flash", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "gemini"}) + assert.Equal(t, 0.0000001, derefF(p.InputCostPerToken)) } func TestGetPricing_NotFound(t *testing.T) { mc := testCatalogWithPricing(nil) - _, ok := mc.getPricing("nonexistent", "openai", schemas.ChatCompletionRequest) - assert.False(t, ok) + p := mc.resolvePricing("openai", "nonexistent", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) + assert.Nil(t, p) } // ========================================================================= @@ -1219,26 +1225,27 @@ func TestResolvePricing_DeploymentFallback(t *testing.T) { }) // Model not found directly, but deployment matches - p := mc.resolvePricing("openai", "gpt-4o-custom", "my-deployment", schemas.ChatCompletionRequest) + p := mc.resolvePricing("openai", "gpt-4o-custom", "my-deployment", schemas.ChatCompletionRequest, PricingLookupScopes{}) require.NotNil(t, p) - assert.Equal(t, 0.000005, p.InputCostPerToken) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) } -func TestResolvePricing_ModelFoundDirectly(t *testing.T) { +func TestResolvePricing_ResolvedModelHasPriority(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), makeKey("my-deployment", "openai", "chat"): chatPricing(0.000001, 0.000002), }) - // Model found directly β€” doesn't fall back to deployment - p := mc.resolvePricing("openai", "gpt-4o", "my-deployment", schemas.ChatCompletionRequest) + // Resolved model ("my-deployment") is looked up first and has priority + // over the originally requested model ("gpt-4o"). + p := mc.resolvePricing("openai", "gpt-4o", "my-deployment", schemas.ChatCompletionRequest, PricingLookupScopes{}) require.NotNil(t, p) - assert.Equal(t, 0.000005, p.InputCostPerToken) + assert.Equal(t, 0.000001, derefF(p.InputCostPerToken)) } func TestResolvePricing_NothingFound(t *testing.T) { mc := testCatalogWithPricing(nil) - p := mc.resolvePricing("openai", "unknown", "", schemas.ChatCompletionRequest) + p := mc.resolvePricing("openai", "unknown", "", schemas.ChatCompletionRequest, PricingLookupScopes{}) assert.Nil(t, p) } @@ -1258,6 +1265,7 @@ func TestNormalizeStreamRequestType(t *testing.T) { {schemas.TranscriptionStreamRequest, schemas.TranscriptionRequest}, {schemas.ImageGenerationStreamRequest, schemas.ImageGenerationRequest}, {schemas.ImageEditStreamRequest, schemas.ImageEditRequest}, + {schemas.RealtimeRequest, schemas.RealtimeRequest}, // realtime is its own base type {schemas.ChatCompletionRequest, schemas.ChatCompletionRequest}, // non-stream unchanged {schemas.EmbeddingRequest, schemas.EmbeddingRequest}, // non-stream unchanged } @@ -1327,14 +1335,14 @@ func TestCalculateCost_200kTier_EndToEnd(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock", "chat"): { Model: "anthropic.claude-3-5-sonnet-20240620-v1:0", Provider: "bedrock", Mode: "chat", - InputCostPerToken: 0.000003, - OutputCostPerToken: 0.000015, - InputCostPerTokenAbove200kTokens: ptr(0.000006), - OutputCostPerTokenAbove200kTokens: ptr(0.00003), - CacheReadInputTokenCost: ptr(0.0000003), - CacheCreationInputTokenCost: ptr(0.00000375), - CacheReadInputTokenCostAbove200kTokens: ptr(0.0000006), - CacheCreationInputTokenCostAbove200kTokens: ptr(0.0000075), + InputCostPerToken: bifrost.Ptr(0.000003), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerTokenAbove200kTokens: bifrost.Ptr(0.000006), + OutputCostPerTokenAbove200kTokens: bifrost.Ptr(0.00003), + CacheReadInputTokenCost: bifrost.Ptr(0.0000003), + CacheCreationInputTokenCost: bifrost.Ptr(0.00000375), + CacheReadInputTokenCostAbove200kTokens: bifrost.Ptr(0.0000006), + CacheCreationInputTokenCostAbove200kTokens: bifrost.Ptr(0.0000075), }, }) @@ -1344,7 +1352,7 @@ func TestCalculateCost_200kTier_EndToEnd(t *testing.T) { TotalTokens: 210000, // Above 200k }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // Tiered rate: input=0.000006, output=0.00003 // 190000*0.000006 + 20000*0.00003 = 1.14 + 0.6 = 1.74 assert.InDelta(t, 1.74, cost, 1e-9) @@ -1365,14 +1373,14 @@ func TestCalculateCost_ProviderCostZeroTotalStillCalculates(t *testing.T) { }, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.InDelta(t, 0.0125, cost, 1e-12) } func TestCalculateCost_AllCachedTokens(t *testing.T) { // All prompt tokens are from cache p := chatPricing(0.000005, 0.000015) - p.CacheReadInputTokenCost = ptr(0.0000005) + p.CacheReadInputTokenCost = bifrost.Ptr(0.0000005) usage := &schemas.BifrostLLMUsage{ PromptTokens: 1000, @@ -1398,8 +1406,8 @@ func TestCalculateCost_ImageGeneration_NilUsage_PerImagePricing(t *testing.T) { Model: "dall-e-3", Provider: "openai", Mode: "image_generation", - InputCostPerToken: 0, - OutputCostPerImage: ptr(0.04), + InputCostPerToken: bifrost.Ptr(0.0), + OutputCostPerImage: bifrost.Ptr(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1407,7 +1415,7 @@ func TestCalculateCost_ImageGeneration_NilUsage_PerImagePricing(t *testing.T) { }) resp := makeImageResponse("openai", "dall-e-3", nil) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 1 image * $0.04 = $0.04 assert.InDelta(t, 0.04, cost, 1e-12) } @@ -1418,8 +1426,8 @@ func TestCalculateCost_ImageGeneration_NilUsage_InputAndOutputPerImage(t *testin Model: "test-image-model", Provider: "test", Mode: "image_generation", - InputCostPerImage: ptr(0.01), - OutputCostPerImage: ptr(0.04), + InputCostPerImage: bifrost.Ptr(0.01), + OutputCostPerImage: bifrost.Ptr(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1427,7 +1435,7 @@ func TestCalculateCost_ImageGeneration_NilUsage_InputAndOutputPerImage(t *testin }) resp := makeImageResponse("test", "test-image-model", nil) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // NumInputImages is 0 (not populated from request), so only output pricing applies // 1 output image * $0.04 = $0.04 assert.InDelta(t, 0.04, cost, 1e-12) @@ -1439,8 +1447,8 @@ func TestCalculateCost_ImageGeneration_WithInputImages(t *testing.T) { Model: "gpt-image-1", Provider: "openai", Mode: "image_generation", - InputCostPerImage: ptr(0.01), - OutputCostPerImage: ptr(0.04), + InputCostPerImage: bifrost.Ptr(0.01), + OutputCostPerImage: bifrost.Ptr(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1450,7 +1458,7 @@ func TestCalculateCost_ImageGeneration_WithInputImages(t *testing.T) { resp := makeImageResponse("openai", "gpt-image-1", &schemas.ImageUsage{ NumInputImages: 2, }) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 2 input images * $0.01 + 1 output image * $0.04 = $0.06 assert.InDelta(t, 0.06, cost, 1e-12) } @@ -1461,7 +1469,7 @@ func TestCalculateCost_ImageGeneration_OutputCountFromData(t *testing.T) { Model: "dall-e-3", Provider: "openai", Mode: "image_generation", - OutputCostPerImage: ptr(0.04), + OutputCostPerImage: bifrost.Ptr(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1476,13 +1484,13 @@ func TestCalculateCost_ImageGeneration_OutputCountFromData(t *testing.T) { {URL: "https://example.com/img3.png", Index: 2}, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: "openai", - ModelRequested: "dall-e-3", + RequestType: schemas.ImageGenerationRequest, + Provider: "openai", + OriginalModelRequested: "dall-e-3", }, }, } - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // 3 output images * $0.04 = $0.12 assert.InDelta(t, 0.12, cost, 1e-12) } @@ -1493,8 +1501,8 @@ func TestCalculateCost_ImageGeneration_NilUsage_NoPerImagePricing(t *testing.T) Model: "token-only-model", Provider: "test", Mode: "image_generation", - InputCostPerToken: 0.000001, - OutputCostPerToken: 0.000002, + InputCostPerToken: bifrost.Ptr(0.000001), + OutputCostPerToken: bifrost.Ptr(0.000002), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1502,7 +1510,7 @@ func TestCalculateCost_ImageGeneration_NilUsage_NoPerImagePricing(t *testing.T) }) resp := makeImageResponse("test", "token-only-model", nil) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) // No per-image pricing and all tokens are zero β†’ 0 assert.InDelta(t, 0.0, cost, 1e-12) } @@ -1513,7 +1521,7 @@ func TestCalculateCost_ImageGeneration_EmptyUsage_PerImagePricing(t *testing.T) Model: "dall-e-3", Provider: "openai", Mode: "image_generation", - OutputCostPerImage: ptr(0.04), + OutputCostPerImage: bifrost.Ptr(0.04), } mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ @@ -1521,16 +1529,16 @@ func TestCalculateCost_ImageGeneration_EmptyUsage_PerImagePricing(t *testing.T) }) resp := makeImageResponse("openai", "dall-e-3", &schemas.ImageUsage{}) - cost := mc.CalculateCost(resp) + cost := mc.CalculateCost(resp, nil) assert.InDelta(t, 0.04, cost, 1e-12) } func TestComputeImageCost_MixedInputTokensOutputPerImage(t *testing.T) { // Input has tokens (text prompt), output has no tokens but per-image pricing p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - OutputCostPerImage: ptr(0.04), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + OutputCostPerImage: bifrost.Ptr(0.04), } usage := &schemas.ImageUsage{ InputTokens: 500, @@ -1545,9 +1553,9 @@ func TestComputeImageCost_MixedInputTokensOutputPerImage(t *testing.T) { func TestComputeImageCost_MixedInputPerImageOutputTokens(t *testing.T) { // Input has no tokens but per-image count, output has tokens p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerImage: ptr(0.01), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerImage: bifrost.Ptr(0.01), } usage := &schemas.ImageUsage{ NumInputImages: 3, @@ -1562,10 +1570,10 @@ func TestComputeImageCost_MixedInputPerImageOutputTokens(t *testing.T) { func TestComputeImageCost_BothHaveTokens_IgnoresPerImage(t *testing.T) { // Both sides have tokens β€” per-image pricing is ignored p := configstoreTables.TableModelPricing{ - InputCostPerToken: 0.000005, - OutputCostPerToken: 0.000015, - InputCostPerImage: ptr(0.01), - OutputCostPerImage: ptr(0.04), + InputCostPerToken: bifrost.Ptr(0.000005), + OutputCostPerToken: bifrost.Ptr(0.000015), + InputCostPerImage: bifrost.Ptr(0.01), + OutputCostPerImage: bifrost.Ptr(0.04), } usage := &schemas.ImageUsage{ InputTokens: 200, diff --git a/framework/modelcatalog/sync.go b/framework/modelcatalog/sync.go index 29c88542a6..69bae551d4 100644 --- a/framework/modelcatalog/sync.go +++ b/framework/modelcatalog/sync.go @@ -414,5 +414,3 @@ func (mc *ModelCatalog) loadModelParametersFromURL(ctx context.Context) (map[str mc.logger.Debug("model-parameters-sync: successfully downloaded and parsed %d model parameters records", len(paramsData)) return paramsData, nil } - - diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index 5326c32906..ab9c01211a 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -3,6 +3,7 @@ package modelcatalog import ( "strings" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) @@ -36,7 +37,7 @@ func normalizeRequestType(reqType schemas.RequestType) string { baseType = "completion" case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: baseType = "chat" - case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest: baseType = "responses" case schemas.EmbeddingRequest: baseType = "embedding" @@ -67,6 +68,8 @@ func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType { return schemas.ChatCompletionRequest case schemas.ResponsesStreamRequest: return schemas.ResponsesRequest + case schemas.RealtimeRequest: + return schemas.RealtimeRequest case schemas.SpeechStreamRequest: return schemas.SpeechRequest case schemas.TranscriptionStreamRequest: @@ -172,15 +175,7 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) // convertTableModelPricingToPricingData converts the TableModelPricing struct to a PricingEntry struct func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry { - return &PricingEntry{ - BaseModel: pricing.BaseModel, - Provider: pricing.Provider, - Mode: pricing.Mode, - ContextLength: pricing.ContextLength, - MaxInputTokens: pricing.MaxInputTokens, - MaxOutputTokens: pricing.MaxOutputTokens, - Architecture: pricing.Architecture, - + options := PricingOptions{ // Costs - Text InputCostPerToken: pricing.InputCostPerToken, OutputCostPerToken: pricing.OutputCostPerToken, @@ -243,4 +238,34 @@ func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModel SearchContextCostPerQuery: pricing.SearchContextCostPerQuery, CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession, } + return &PricingEntry{ + BaseModel: pricing.BaseModel, + Provider: pricing.Provider, + Mode: pricing.Mode, + ContextLength: pricing.ContextLength, + MaxInputTokens: pricing.MaxInputTokens, + MaxOutputTokens: pricing.MaxOutputTokens, + Architecture: pricing.Architecture, + PricingOptions: options, + } +} + +// convertTablePricingOverrideToPricingOverride converts a TablePricingOverride to a PricingOverride. +func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) { + var options PricingOptions + if err := sonic.Unmarshal([]byte(override.PricingPatchJSON), &options); err != nil { + return PricingOverride{}, err + } + return PricingOverride{ + ID: override.ID, + Name: override.Name, + ScopeKind: ScopeKind(override.ScopeKind), + VirtualKeyID: override.VirtualKeyID, + ProviderID: override.ProviderID, + ProviderKeyID: override.ProviderKeyID, + MatchType: MatchType(override.MatchType), + Pattern: override.Pattern, + RequestTypes: override.RequestTypes, + Options: options, + }, nil } diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index 8666eb3697..44667932aa 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -680,3 +680,349 @@ func generateSecureRandomString(length int) (string, error) { } return base64.URLEncoding.EncodeToString(bytes)[:length], nil } + +// generateSessionToken generates a cryptographically secure opaque session token (hex-encoded) +func generateSessionToken() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate session token: %w", err) + } + return fmt.Sprintf("%x", bytes), nil +} + +// ---------- Per-User OAuth Methods ---------- + +// InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL. +// It reuses the template OAuth config (which holds client_id, token_url, etc.) to build the flow. +func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*schemas.OAuth2FlowInitiation, string, error) { + // Load the template OAuth config + templateConfig, err := p.configStore.GetOauthConfigByID(ctx, oauthConfigID) + if err != nil { + return nil, "", fmt.Errorf("failed to load template oauth config: %w", err) + } + if templateConfig == nil { + return nil, "", schemas.ErrOAuth2ConfigNotFound + } + + // Generate state token for CSRF protection + state, err := generateSecureRandomString(32) + if err != nil { + return nil, "", fmt.Errorf("failed to generate state token: %w", err) + } + + // Generate PKCE challenge + codeVerifier, codeChallenge, err := GeneratePKCEChallenge() + if err != nil { + return nil, "", fmt.Errorf("failed to generate PKCE challenge: %w", err) + } + + // Parse scopes from template config + var scopes []string + if templateConfig.Scopes != "" { + json.Unmarshal([]byte(templateConfig.Scopes), &scopes) + } + + // Create per-user OAuth session + sessionID := uuid.New().String() + expiresAt := time.Now().Add(15 * time.Minute) + + // Propagate identity from context so the callback can link the token to the user + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + userID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceUserID).(string) + // For OSS: prefer X-Bf-User-Id header as user identity + if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { + userID = mcpUserID + } + + // If a Bifrost MCP session token is present in context, reuse it as the session token + // so the MCP server token is stored under the same key used for subsequent lookups. + // Otherwise generate a fresh token. + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + if sessionToken == "" { + sessionToken, err = generateSessionToken() + if err != nil { + return nil, "", fmt.Errorf("failed to generate session token: %w", err) + } + } + var vkId *string + if virtualKeyID != "" { + vkId = &virtualKeyID + } + var uid *string + if userID != "" { + uid = &userID + } + session := &tables.TableOauthUserSession{ + ID: sessionID, + MCPClientID: mcpClientID, + OauthConfigID: oauthConfigID, + State: state, + RedirectURI: redirectURI, + CodeVerifier: codeVerifier, + SessionToken: sessionToken, + VirtualKeyID: vkId, + UserID: uid, + Status: "pending", + ExpiresAt: expiresAt, + } + + if err := p.configStore.CreateOauthUserSession(ctx, session); err != nil { + return nil, "", fmt.Errorf("failed to create per-user oauth session: %w", err) + } + + // Build authorize URL with PKCE + authURL := p.buildAuthorizeURLWithPKCE( + templateConfig.AuthorizeURL, + templateConfig.ClientID, + redirectURI, + state, + codeChallenge, + scopes, + ) + + logger.Debug("Per-user OAuth flow initiated: session_id=%s, mcp_client_id=%s", sessionID, mcpClientID) + + return &schemas.OAuth2FlowInitiation{ + OauthConfigID: oauthConfigID, + AuthorizeURL: authURL, + State: state, + ExpiresAt: expiresAt, + }, sessionID, nil +} + +// CompleteUserOAuthFlow handles the OAuth callback for a per-user flow. +// It looks up the session by state, exchanges code for tokens, and returns a session token. +func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) { + // Atomically claim session by state to prevent concurrent callback races + session, err := p.configStore.ClaimOauthUserSessionByState(ctx, state) + if err != nil { + return "", fmt.Errorf("failed to claim per-user oauth session: %w", err) + } + if session == nil { + // State not found or already claimed β€” not a per-user session + return "", schemas.ErrOAuth2NotPerUserSession + } + + // Check expiry + if time.Now().After(session.ExpiresAt) { + session.Status = "expired" + p.configStore.UpdateOauthUserSession(ctx, session) + return "", fmt.Errorf("per-user oauth flow expired") + } + + // Load template OAuth config for token_url, client_id, etc. + templateConfig, err := p.configStore.GetOauthConfigByID(ctx, session.OauthConfigID) + if err != nil || templateConfig == nil { + session.Status = "failed" + p.configStore.UpdateOauthUserSession(ctx, session) + return "", fmt.Errorf("failed to load template oauth config: %w", err) + } + // Exchange code for tokens with PKCE verifier + // Use the redirect URI stored in the session (same one used in authorize step) + // to satisfy OAuth spec requirement that redirect_uri must match + redirectURI := session.RedirectURI + if redirectURI == "" { + redirectURI = templateConfig.RedirectURI + } + tokenResponse, err := p.exchangeCodeForTokensWithPKCE( + templateConfig.TokenURL, + code, + templateConfig.ClientID, + templateConfig.ClientSecret, + redirectURI, + session.CodeVerifier, + ) + if err != nil { + session.Status = "failed" + p.configStore.UpdateOauthUserSession(ctx, session) + return "", fmt.Errorf("per-user token exchange failed: %w", err) + } + + // Use existing session token if set (e.g., Bifrost session ID from MCP spec OAuth flow), + // otherwise generate a new one (for standalone per-user OAuth). + sessionToken := session.SessionToken + if sessionToken == "" { + sessionToken, err = generateSessionToken() + if err != nil { + session.Status = "failed" + p.configStore.UpdateOauthUserSession(ctx, session) + return "", err + } + } + + // Parse scopes + var scopes []string + if tokenResponse.Scope != "" { + scopes = strings.Split(tokenResponse.Scope, " ") + } + scopesJSON, _ := json.Marshal(scopes) + + // Create per-user OAuth token record, propagating identity from session + tokenRecord := &tables.TableOauthUserToken{ + ID: uuid.New().String(), + SessionToken: sessionToken, + VirtualKeyID: session.VirtualKeyID, + UserID: session.UserID, + MCPClientID: session.MCPClientID, + OauthConfigID: session.OauthConfigID, + AccessToken: strings.TrimSpace(tokenResponse.AccessToken), + RefreshToken: strings.TrimSpace(tokenResponse.RefreshToken), + TokenType: tokenResponse.TokenType, + ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second), + Scopes: string(scopesJSON), + } + if err := p.configStore.CreateOauthUserToken(ctx, tokenRecord); err != nil { + return "", fmt.Errorf("failed to create per-user oauth token: %w", err) + } + + // Update session with session token and mark as authorized + session.SessionToken = sessionToken + session.Status = "authorized" + if err := p.configStore.UpdateOauthUserSession(ctx, session); err != nil { + return "", fmt.Errorf("failed to update per-user oauth session: %w", err) + } + + logger.Debug("Per-user OAuth flow completed: session_id=%s, mcp_client_id=%s", session.ID, session.MCPClientID) + + return sessionToken, nil +} + +// GetUserAccessToken retrieves the access token for a per-user OAuth session. +// If the token is expired, it automatically attempts a refresh. +func (p *OAuth2Provider) GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) { + token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken) + if err != nil { + return "", fmt.Errorf("failed to load per-user oauth token: %w", err) + } + if token == nil { + return "", fmt.Errorf("per-user oauth token not found for session") + } + + // Check if token is expired + if time.Now().After(token.ExpiresAt) { + if err := p.RefreshUserAccessToken(ctx, sessionToken); err != nil { + return "", fmt.Errorf("per-user token expired and refresh failed: %w", err) + } + // Reload token after refresh + token, err = p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken) + if err != nil || token == nil { + return "", fmt.Errorf("failed to reload per-user token after refresh") + } + } + + accessToken := strings.TrimSpace(token.AccessToken) + if accessToken == "" { + return "", fmt.Errorf("per-user access token is empty after sanitization") + } + return accessToken, nil +} + +// GetUserAccessTokenByIdentity retrieves the upstream access token for a user +// identified by virtualKeyID, userID, or sessionToken (fallback), for a specific +// MCP client. Identity-based lookups persist tokens across sessions so users don't +// need to re-authenticate with upstream providers on reconnect. +func (p *OAuth2Provider) GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) { + token, err := p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID) + if err != nil { + return "", fmt.Errorf("failed to load per-user oauth token by identity: %w", err) + } + if token == nil { + return "", schemas.ErrOAuth2TokenNotFound + } + + // Check if token is expired β€” attempt refresh + if time.Now().After(token.ExpiresAt) { + if token.SessionToken != "" { + if err := p.RefreshUserAccessToken(ctx, token.SessionToken); err != nil { + return "", fmt.Errorf("per-user token expired and refresh failed: %w", err) + } + // Reload after refresh + token, err = p.configStore.GetOauthUserTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, mcpClientID) + if err != nil || token == nil { + return "", fmt.Errorf("failed to reload per-user token after refresh") + } + } else { + return "", fmt.Errorf("per-user token expired and no session token available for refresh") + } + } + + accessToken := strings.TrimSpace(token.AccessToken) + if accessToken == "" { + return "", fmt.Errorf("per-user access token is empty after sanitization") + } + return accessToken, nil +} + +// RefreshUserAccessToken refreshes a per-user OAuth access token. +func (p *OAuth2Provider) RefreshUserAccessToken(ctx context.Context, sessionToken string) error { + p.mu.Lock() + defer p.mu.Unlock() + + token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken) + if err != nil || token == nil { + return fmt.Errorf("per-user oauth token not found: %w", err) + } + + if token.RefreshToken == "" { + return fmt.Errorf("no refresh token available for per-user oauth session") + } + + // Load template OAuth config for token_url, client_id, etc. + templateConfig, err := p.configStore.GetOauthConfigByID(ctx, token.OauthConfigID) + if err != nil || templateConfig == nil { + return fmt.Errorf("failed to load template oauth config for refresh: %w", err) + } + + // Exchange refresh token + newTokenResponse, err := p.exchangeRefreshToken( + templateConfig.TokenURL, + templateConfig.ClientID, + templateConfig.ClientSecret, + token.RefreshToken, + ) + if err != nil { + return fmt.Errorf("per-user token refresh failed: %w", err) + } + + // Update token + now := time.Now() + token.AccessToken = strings.TrimSpace(newTokenResponse.AccessToken) + if newTokenResponse.RefreshToken != "" { + token.RefreshToken = strings.TrimSpace(newTokenResponse.RefreshToken) + } + token.ExpiresAt = now.Add(time.Duration(newTokenResponse.ExpiresIn) * time.Second) + token.LastRefreshedAt = &now + + if err := p.configStore.UpdateOauthUserToken(ctx, token); err != nil { + return fmt.Errorf("failed to update per-user token after refresh: %w", err) + } + + logger.Debug("Per-user OAuth token refreshed: session_token=...%s", sessionToken[len(sessionToken)-4:]) + return nil +} + +// RevokeUserToken revokes a per-user OAuth token and marks the session as revoked. +func (p *OAuth2Provider) RevokeUserToken(ctx context.Context, sessionToken string) error { + p.mu.Lock() + defer p.mu.Unlock() + + token, err := p.configStore.GetOauthUserTokenBySessionToken(ctx, sessionToken) + if err != nil || token == nil { + return fmt.Errorf("per-user oauth token not found: %w", err) + } + + // Delete the token + if err := p.configStore.DeleteOauthUserToken(ctx, token.ID); err != nil { + return fmt.Errorf("failed to delete per-user oauth token: %w", err) + } + + // Update session status + session, err := p.configStore.GetOauthUserSessionBySessionToken(ctx, sessionToken) + if err == nil && session != nil { + session.Status = "revoked" + p.configStore.UpdateOauthUserSession(ctx, session) + } + + logger.Debug("Per-user OAuth token revoked: session_token=...%s", sessionToken[len(sessionToken)-4:]) + return nil +} diff --git a/framework/streaming/accumulator_test.go b/framework/streaming/accumulator_test.go index f7df86b565..18eb43f71b 100644 --- a/framework/streaming/accumulator_test.go +++ b/framework/streaming/accumulator_test.go @@ -64,10 +64,10 @@ func TestChatStreamingFinalChunkNoDeadlock(t *testing.T) { TotalTokens: 150, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: schemas.Anthropic, - ModelRequested: "claude-opus-4", - ChunkIndex: 9, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: schemas.Anthropic, + OriginalModelRequested: "claude-opus-4", + ChunkIndex: 9, }, }, } @@ -140,10 +140,10 @@ func TestResponsesStreamingFinalChunkNoDeadlock(t *testing.T) { OutputTokens: 50, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: schemas.Anthropic, - ModelRequested: "claude-opus-4", - ChunkIndex: 4, + RequestType: schemas.ResponsesStreamRequest, + Provider: schemas.Anthropic, + OriginalModelRequested: "claude-opus-4", + ChunkIndex: 4, }, }, } @@ -488,10 +488,10 @@ func TestAudioStreamingFinalChunkNoDeadlock(t *testing.T) { TotalTokens: 150, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "tts-1", - ChunkIndex: 7, + RequestType: schemas.SpeechStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "tts-1", + ChunkIndex: 7, }, }, } @@ -559,10 +559,10 @@ func TestTranscriptionStreamingFinalChunkNoDeadlock(t *testing.T) { TranscriptionResponse: &schemas.BifrostTranscriptionResponse{ Text: "Complete transcription", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: schemas.OpenAI, - ModelRequested: "whisper-1", - ChunkIndex: 5, + RequestType: schemas.TranscriptionStreamRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "whisper-1", + ChunkIndex: 5, }, }, } diff --git a/framework/streaming/audio.go b/framework/streaming/audio.go index d36fb47d36..0390ea5aaf 100644 --- a/framework/streaming/audio.go +++ b/framework/streaming/audio.go @@ -8,6 +8,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" ) // buildCompleteMessageFromAudioStreamChunks builds a complete message from accumulated audio chunks @@ -120,7 +121,7 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) // For audio, all the data comes in the final chunk chunk := a.getAudioStreamChunk() @@ -145,7 +146,7 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, chunk.ChunkIndex = result.SpeechStreamResponse.ExtraFields.ChunkIndex if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -176,21 +177,23 @@ func (a *Accumulator) processAudioStreamingResponse(ctx *schemas.BifrostContext, rawRequest = result.SpeechStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeAudio, - Model: model, - Provider: provider, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeAudio, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Provider: provider, + Data: data, + RawRequest: &rawRequest, }, nil } // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeAudio, - Model: model, - Provider: provider, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeAudio, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Provider: provider, + Data: nil, }, nil } diff --git a/framework/streaming/chat.go b/framework/streaming/chat.go index dafd170902..6602e9a21a 100644 --- a/framework/streaming/chat.go +++ b/framework/streaming/chat.go @@ -8,6 +8,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" ) // deepCopyChatStreamDelta creates a deep copy of ChatStreamResponseChoiceDelta @@ -463,7 +464,7 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - requestType, provider, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, provider, model, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) streamType := StreamTypeChat if requestType == schemas.TextCompletionStreamRequest { @@ -495,9 +496,12 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, chunk.TokenUsage = result.TextCompletionResponse.Usage } chunk.ChunkIndex = result.TextCompletionResponse.ExtraFields.ChunkIndex + if result.TextCompletionResponse.ExtraFields.RawResponse != nil { + chunk.RawResponse = bifrost.Ptr(fmt.Sprintf("%v", result.TextCompletionResponse.ExtraFields.RawResponse)) + } if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -523,7 +527,7 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, } if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -560,7 +564,8 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, RequestID: requestID, StreamType: streamType, Provider: provider, - Model: model, + RequestedModel: model, + ResolvedModel: resolvedModel, Data: data, RawRequest: &rawRequest, }, nil @@ -568,10 +573,11 @@ func (a *Accumulator) processChatStreamingResponse(ctx *schemas.BifrostContext, // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: streamType, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: streamType, + Provider: provider, + RequestedModel: model, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/images.go b/framework/streaming/images.go index 23b2dd8f5c..367b52c037 100644 --- a/framework/streaming/images.go +++ b/framework/streaming/images.go @@ -8,6 +8,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" ) // buildCompleteImageFromImageStreamChunks builds a complete image generation response from accumulated chunks @@ -19,7 +20,7 @@ func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStr finalResponse := &schemas.BifrostImageGenerationResponse{ ID: chunks[i].Delta.ID, Created: chunks[i].Delta.CreatedAt, - Model: chunks[i].Delta.ExtraFields.ModelRequested, + Model: chunks[i].Delta.ExtraFields.OriginalModelRequested, Data: []schemas.ImageData{ { B64JSON: chunks[i].Delta.B64JSON, @@ -52,8 +53,8 @@ func (a *Accumulator) buildCompleteImageFromImageStreamChunks(chunks []*ImageStr } // Extract metadata - if model == "" && chunk.Delta.ExtraFields.ModelRequested != "" { - model = chunk.Delta.ExtraFields.ModelRequested + if model == "" && chunk.Delta.ExtraFields.OriginalModelRequested != "" { + model = chunk.Delta.ExtraFields.OriginalModelRequested } // Store revised prompt if present (usually in first chunk) @@ -215,7 +216,7 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) chunk := a.getImageStreamChunk() @@ -273,7 +274,7 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -309,12 +310,13 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, rawRequest = result.ImageGenerationStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeImage, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } @@ -324,10 +326,11 @@ func (a *Accumulator) processImageStreamingResponse(ctx *schemas.BifrostContext, // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeImage, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeImage, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/responses.go b/framework/streaming/responses.go index 65c7635226..56c461cbcd 100644 --- a/framework/streaming/responses.go +++ b/framework/streaming/responses.go @@ -8,6 +8,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" ) // deepCopyResponsesStreamResponse creates a deep copy of BifrostResponsesStreamResponse @@ -889,7 +890,7 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) chunk := a.getResponsesStreamChunk() @@ -912,7 +913,7 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont chunk.ChunkIndex = result.ResponsesStreamResponse.ExtraFields.ChunkIndex if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -948,20 +949,22 @@ func (a *Accumulator) processResponsesStreamingResponse(ctx *schemas.BifrostCont } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeResponses, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeResponses, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/transcription.go b/framework/streaming/transcription.go index 593c7f80b2..3367e25ad6 100644 --- a/framework/streaming/transcription.go +++ b/framework/streaming/transcription.go @@ -8,6 +8,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/modelcatalog" ) // buildCompleteMessageFromTranscriptionStreamChunks builds a complete message from accumulated transcription chunks @@ -130,7 +131,7 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost // Log error but don't fail the request return nil, fmt.Errorf("accumulator-id not found in context or is empty") } - _, provider, model := bifrost.GetResponseFields(result, bifrostErr) + _, provider, requestedModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) // For audio, all the data comes in the final chunk chunk := a.getTranscriptionStreamChunk() @@ -162,7 +163,7 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost } if isFinalChunk { if a.pricingManager != nil { - cost := a.pricingManager.CalculateCost(result) + cost := a.pricingManager.CalculateCost(result, modelcatalog.PricingLookupScopesFromContext(ctx, string(result.GetExtraFields().Provider))) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.GetExtraFields().CacheDebug @@ -193,21 +194,23 @@ func (a *Accumulator) processTranscriptionStreamingResponse(ctx *schemas.Bifrost rawRequest = result.TranscriptionStreamResponse.ExtraFields.RawRequest } return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeTranscription, - Provider: provider, - Model: model, - Data: data, - RawRequest: &rawRequest, + RequestID: requestID, + StreamType: StreamTypeTranscription, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: data, + RawRequest: &rawRequest, }, nil } // Non-final chunk: skip expensive rebuild since no consumer uses intermediate data. // Both logging and maxim plugins return early when !isFinalChunk. return &ProcessedStreamResponse{ - RequestID: requestID, - StreamType: StreamTypeTranscription, - Provider: provider, - Model: model, - Data: nil, + RequestID: requestID, + StreamType: StreamTypeTranscription, + Provider: provider, + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + Data: nil, }, nil } diff --git a/framework/streaming/types.go b/framework/streaming/types.go index 9d7cf0183f..eb9d10e3ff 100644 --- a/framework/streaming/types.go +++ b/framework/streaming/types.go @@ -228,12 +228,13 @@ func (sa *StreamAccumulator) getLastAudioChunkLocked() *AudioStreamChunk { // ProcessedStreamResponse represents a processed streaming response type ProcessedStreamResponse struct { - RequestID string - StreamType StreamType - Provider schemas.ModelProvider - Model string - Data *AccumulatedData - RawRequest *interface{} + RequestID string + StreamType StreamType + Provider schemas.ModelProvider + RequestedModel string // original model requested by the caller + ResolvedModel string // actual model used by the provider (equals RequestedModel when no alias mapping exists) + Data *AccumulatedData + RawRequest *interface{} } // ToBifrostResponse converts a ProcessedStreamResponse to a BifrostResponse @@ -253,7 +254,7 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { textResp := &schemas.BifrostTextCompletionResponse{ ID: p.RequestID, Object: "text_completion", - Model: p.Model, + Model: p.RequestedModel, Choices: []schemas.BifrostResponseChoice{ { Index: 0, @@ -269,10 +270,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { resp.TextCompletionResponse = textResp resp.TextCompletionResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.TextCompletionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.TextCompletionResponse.ExtraFields.RawRequest = p.RawRequest @@ -297,7 +299,7 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { chatResp := &schemas.BifrostChatResponse{ ID: p.RequestID, Object: "chat.completion", - Model: p.Model, + Model: p.RequestedModel, Created: int(p.Data.StartTimestamp.Unix()), Choices: []schemas.BifrostResponseChoice{ { @@ -314,10 +316,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { resp.ChatResponse = chatResp resp.ChatResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ChatCompletionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.ChatResponse.ExtraFields.RawRequest = p.RawRequest @@ -338,10 +341,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { responsesResp.Usage = p.Data.TokenUsage.ToResponsesResponseUsage() } responsesResp.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ResponsesRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { responsesResp.ExtraFields.RawRequest = p.RawRequest @@ -360,10 +364,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.SpeechResponse = speechResp resp.SpeechResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.SpeechRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.SpeechResponse.ExtraFields.RawRequest = p.RawRequest @@ -381,14 +386,21 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.TranscriptionResponse = transcriptionResp resp.TranscriptionResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.TranscriptionRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.TranscriptionResponse.ExtraFields.RawRequest = p.RawRequest } + if p.Data.RawResponse != nil { + resp.TranscriptionResponse.ExtraFields.RawResponse = *p.Data.RawResponse + } + if p.Data.CacheDebug != nil { + resp.TranscriptionResponse.ExtraFields.CacheDebug = p.Data.CacheDebug + } case StreamTypeImage: imageResp := p.Data.ImageGenerationOutput if imageResp == nil { @@ -398,8 +410,8 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { if p.RequestID != "" { imageResp.ID = p.RequestID } - if p.Model != "" { - imageResp.Model = p.Model + if p.RequestedModel != "" { + imageResp.Model = p.RequestedModel } } // Ensure Data is never nil to serialize as [] instead of null @@ -408,10 +420,11 @@ func (p *ProcessedStreamResponse) ToBifrostResponse() *schemas.BifrostResponse { } resp.ImageGenerationResponse = imageResp resp.ImageGenerationResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: p.Provider, - ModelRequested: p.Model, - Latency: p.Data.Latency, + RequestType: schemas.ImageGenerationRequest, + Provider: p.Provider, + OriginalModelRequested: p.RequestedModel, + ResolvedModelUsed: p.ResolvedModel, + Latency: p.Data.Latency, } if p.RawRequest != nil { resp.ImageGenerationResponse.ExtraFields.RawRequest = p.RawRequest diff --git a/framework/tracing/store.go b/framework/tracing/store.go index ab5fd88b2e..340ebf3c4c 100644 --- a/framework/tracing/store.go +++ b/framework/tracing/store.go @@ -73,7 +73,7 @@ func NewTraceStore(ttl time.Duration, logger schemas.Logger) *TraceStore { // If empty, a new trace ID will be generated. // Note: The parent span ID (for linking to upstream spans) is handled separately // via context in StartSpan, not stored on the trace itself. -func (s *TraceStore) CreateTrace(inheritedTraceID string) string { +func (s *TraceStore) CreateTrace(inheritedTraceID string, requestID ...string) string { trace := s.tracePool.Get().(*schemas.Trace) // Reset and initialize the trace if inheritedTraceID != "" { @@ -85,6 +85,9 @@ func (s *TraceStore) CreateTrace(inheritedTraceID string) string { // Parent-child relationships are between spans, not traces. // The root span's ParentID is set in StartSpan from context. trace.ParentID = "" + if len(requestID) > 0 { + trace.RequestID = requestID[0] + } trace.StartTime = time.Now() trace.EndTime = time.Time{} trace.RootSpan = nil @@ -115,6 +118,15 @@ func (s *TraceStore) GetTrace(traceID string) *schemas.Trace { return nil } +// SetRequestID sets the request ID for the trace +func (s *TraceStore) SetRequestID(traceID string, requestID string) { + trace := s.GetTrace(traceID) + if trace == nil { + return + } + trace.SetRequestID(requestID) +} + // CompleteTrace marks the trace as complete, removes it from store, and returns it for flushing func (s *TraceStore) CompleteTrace(traceID string) *schemas.Trace { // Clear any deferred span for this trace diff --git a/framework/tracing/tracer.go b/framework/tracing/tracer.go index 5c9e95f15a..55fd8873d3 100644 --- a/framework/tracing/tracer.go +++ b/framework/tracing/tracer.go @@ -3,6 +3,9 @@ package tracing import ( "context" + "strings" + "sync" + "sync/atomic" "time" "github.com/maximhq/bifrost/core/schemas" @@ -18,6 +21,9 @@ type Tracer struct { store *TraceStore accumulator *streaming.Accumulator pricingManager *modelcatalog.ModelCatalog + logger schemas.Logger + obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] + flushWG sync.WaitGroup } // NewTracer creates a new Tracer wrapping the given TraceStore. @@ -28,12 +34,22 @@ func NewTracer(store *TraceStore, pricingManager *modelcatalog.ModelCatalog, log store: store, accumulator: streaming.NewAccumulator(pricingManager, logger), pricingManager: pricingManager, + logger: logger, + obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, } } +// SetObservabilityPlugins updates the plugins that receive completed traces. +func (t *Tracer) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { + if t == nil { + return + } + t.obsPlugins.Store(&obsPlugins) +} + // CreateTrace creates a new trace with optional parent ID and returns the trace ID. -func (t *Tracer) CreateTrace(parentID string) string { - return t.store.CreateTrace(parentID) +func (t *Tracer) CreateTrace(parentID string, requestID ...string) string { + return t.store.CreateTrace(parentID, requestID...) } // EndTrace completes a trace and returns the trace data for observation/export. @@ -164,7 +180,7 @@ func (t *Tracer) PopulateLLMRequestAttributes(handle schemas.SpanHandle, req *sc } // PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span. -func (t *Tracer) PopulateLLMResponseAttributes(handle schemas.SpanHandle, resp *schemas.BifrostResponse, err *schemas.BifrostError) { +func (t *Tracer) PopulateLLMResponseAttributes(ctx *schemas.BifrostContext, handle schemas.SpanHandle, resp *schemas.BifrostResponse, err *schemas.BifrostError) { h, ok := handle.(*spanHandle) if !ok || h == nil { return @@ -185,7 +201,7 @@ func (t *Tracer) PopulateLLMResponseAttributes(handle schemas.SpanHandle, resp * } // Populate cost attribute using pricing manager if t.pricingManager != nil && resp != nil { - cost := t.pricingManager.CalculateCost(resp) + cost := t.pricingManager.CalculateCost(resp, modelcatalog.PricingLookupScopesFromContext(ctx, string(resp.GetExtraFields().Provider))) span.SetAttribute(schemas.AttrUsageCost, cost) } } @@ -306,9 +322,10 @@ func (t *Tracer) ProcessStreamingChunk(traceID string, isFinalChunk bool, result // Convert ProcessedStreamResponse to StreamAccumulatorResult accResult := &schemas.StreamAccumulatorResult{ - RequestID: processedResp.RequestID, - Model: processedResp.Model, - Provider: processedResp.Provider, + RequestID: processedResp.RequestID, + RequestedModel: processedResp.RequestedModel, + ResolvedModel: processedResp.ResolvedModel, + Provider: processedResp.Provider, } if processedResp.Data != nil { @@ -344,9 +361,22 @@ func (t *Tracer) GetAccumulator() *streaming.Accumulator { return t.accumulator } +// AttachPluginLogs appends plugin log entries to the trace identified by traceID. +func (t *Tracer) AttachPluginLogs(traceID string, logs []schemas.PluginLogEntry) { + if len(logs) == 0 || traceID == "" { + return + } + trace := t.store.GetTrace(traceID) + if trace == nil { + return + } + trace.AppendPluginLogs(logs) +} + // Stop stops the tracer and releases its resources. // This stops the internal TraceStore's cleanup goroutine. func (t *Tracer) Stop() { + t.flushWG.Wait() if t.store != nil { t.store.Stop() } @@ -355,5 +385,56 @@ func (t *Tracer) Stop() { } } +// CompleteAndFlushTrace ends a trace and forwards it to any observability +// plugins asynchronously. Realtime transports need this explicit flush because +// they bypass the HTTP tracing middleware that normally injects completed traces. +func (t *Tracer) CompleteAndFlushTrace(traceID string) { + if t == nil { + return + } + if strings.TrimSpace(traceID) == "" { + return + } + t.flushWG.Go(func() { + completedTrace := t.EndTrace(strings.TrimSpace(traceID)) + if completedTrace == nil { + return + } + // Defer release so the pooled trace is returned even if a plugin panics; + // otherwise an unrecovered panic in this detached goroutine leaks the + // trace object and takes down the whole process. + defer t.ReleaseTrace(completedTrace) + + var obsPlugins []schemas.ObservabilityPlugin + if loaded := t.obsPlugins.Load(); loaded != nil { + obsPlugins = *loaded + } + seen := make(map[string]struct{}, len(obsPlugins)) + for _, plugin := range obsPlugins { + if plugin == nil { + continue + } + // Isolate each plugin callback β€” one bad observability backend should + // not crash the server or prevent other plugins from receiving the trace. + func(plugin schemas.ObservabilityPlugin) { + name := "" + defer func() { + if r := recover(); r != nil && t.logger != nil { + t.logger.Error("observability plugin %s panicked during trace injection: %v", name, r) + } + }() + name = plugin.GetName() + if _, exists := seen[name]; exists { + return + } + seen[name] = struct{}{} + if err := plugin.Inject(context.Background(), completedTrace); err != nil && t.logger != nil { + t.logger.Warn("observability plugin %s failed to inject trace: %v", name, err) + } + }(plugin) + } + }) +} + // Ensure Tracer implements schemas.Tracer at compile time var _ schemas.Tracer = (*Tracer)(nil) diff --git a/framework/tracing/tracer_test.go b/framework/tracing/tracer_test.go index 372e075829..33134c67d2 100644 --- a/framework/tracing/tracer_test.go +++ b/framework/tracing/tracer_test.go @@ -8,6 +8,57 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +type testRealtimeObservabilityPlugin struct { + injected chan *schemas.Trace +} + +func (p *testRealtimeObservabilityPlugin) GetName() string { return "test-observability" } +func (p *testRealtimeObservabilityPlugin) Cleanup() error { return nil } +func (p *testRealtimeObservabilityPlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + return req, nil, nil +} +func (p *testRealtimeObservabilityPlugin) PostLLMHook(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} +func (p *testRealtimeObservabilityPlugin) Inject(_ context.Context, trace *schemas.Trace) error { + if trace == nil { + p.injected <- nil + return nil + } + traceCopy := *trace + p.injected <- &traceCopy + return nil +} + +func TestTracer_CompleteAndFlushTraceInjectsObservabilityPlugins(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, nil) + defer tracer.Stop() + + traceID := tracer.CreateTrace("") + plugin := &testRealtimeObservabilityPlugin{ + injected: make(chan *schemas.Trace, 1), + } + + tracer.SetObservabilityPlugins([]schemas.ObservabilityPlugin{plugin}) + tracer.CompleteAndFlushTrace(traceID) + + select { + case trace := <-plugin.injected: + if trace == nil || trace.TraceID != traceID { + t.Fatalf("injected trace = %+v, want trace %q", trace, traceID) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for observability inject") + } + + if got := tracer.store.GetTrace(traceID); got != nil { + t.Fatalf("trace %q was not released after flush", traceID) + } +} + func TestTracer_StartSpan_RootSpanWithW3CParent(t *testing.T) { // This is the key test: verifies that when an incoming request has a W3C traceparent header, // the root span in Bifrost correctly links to the upstream service's span. diff --git a/framework/vectorstore/redis.go b/framework/vectorstore/redis.go index a9e4c5f94e..847cad59b1 100644 --- a/framework/vectorstore/redis.go +++ b/framework/vectorstore/redis.go @@ -1610,5 +1610,9 @@ func newRedisStore(_ context.Context, config RedisConfig, logger schemas.Logger) logger: logger, namespaceFieldTypes: make(map[string]map[string]VectorStorePropertyType), } + // Eagerly verify connectivity, consistent with other store constructors (e.g. Qdrant) + if err := store.Ping(context.Background()); err != nil { + return nil, fmt.Errorf("failed to connect to redis: %w", err) + } return store, nil } diff --git a/framework/version b/framework/version index 84e3118775..6261a05bb0 100644 --- a/framework/version +++ b/framework/version @@ -1 +1 @@ -1.2.36 \ No newline at end of file +1.3.1 \ No newline at end of file diff --git a/helm-charts/bifrost/Chart.yaml b/helm-charts/bifrost/Chart.yaml index b4613d3d66..897c4db319 100644 --- a/helm-charts/bifrost/Chart.yaml +++ b/helm-charts/bifrost/Chart.yaml @@ -16,5 +16,4 @@ sources: maintainers: - name: Bifrost Team email: support@getbifrost.ai -icon: https://www.getbifrost.ai/favicon.png - +icon: https://www.getbifrost.ai/favicon.png \ No newline at end of file diff --git a/helm-charts/bifrost/README.md b/helm-charts/bifrost/README.md index 86b14a1357..805b0a7058 100644 --- a/helm-charts/bifrost/README.md +++ b/helm-charts/bifrost/README.md @@ -14,8 +14,17 @@ Official Helm charts for deploying [Bifrost](https://github.com/maximhq/bifrost) ### v2.0.15 -- Added `whitelistedRoutes` client config property for routes that bypass auth middleware -- Added `whitelistedRoutes` to Helm schema, values, and template rendering +- Synced helm schema with transport `config.schema.json` β€” added missing properties: + - `client.mcpDisableAutoToolInject` β€” disable automatic MCP tool injection + - `governance.budgets[].calendar_aligned` β€” snap budget resets to calendar boundaries + - `governance.pricingOverrides` β€” scoped pricing overrides for the model catalog + - `mcp.clientConfigs[].allowedExtraHeaders` β€” header allowlist per MCP client + - `mcp.clientConfigs[].allowOnAllVirtualKeys` β€” make MCP server accessible to all virtual keys + - `mcp.toolManagerConfig.disableAutoToolInject` β€” disable auto tool injection at manager level + - `networkConfig.beta_header_overrides` β€” override Anthropic beta header support per provider + - `websocket` β€” full WebSocket gateway tuning (connections, pool, transcript buffer) +- Fixed SSE `connectionString` not being rendered in `_helpers.tpl` for MCP clients +- Added template rendering for all new properties in `_helpers.tpl` ### v2.0.14 diff --git a/helm-charts/bifrost/templates/_helpers.tpl b/helm-charts/bifrost/templates/_helpers.tpl index 3dbb1d6574..8dc0606658 100644 --- a/helm-charts/bifrost/templates/_helpers.tpl +++ b/helm-charts/bifrost/templates/_helpers.tpl @@ -284,6 +284,9 @@ false {{- if hasKey .Values.bifrost.client "hideDeletedVirtualKeysInFilters" }} {{- $_ := set $client "hide_deleted_virtual_keys_in_filters" .Values.bifrost.client.hideDeletedVirtualKeysInFilters }} {{- end }} +{{- if hasKey .Values.bifrost.client "mcpDisableAutoToolInject" }} +{{- $_ := set $client "mcp_disable_auto_tool_inject" .Values.bifrost.client.mcpDisableAutoToolInject }} +{{- end }} {{- $_ := set $config "client" $client }} {{- end }} {{- /* Framework */ -}} @@ -357,6 +360,9 @@ false {{- if .Values.bifrost.governance.providers }} {{- $_ := set $governance "providers" .Values.bifrost.governance.providers }} {{- end }} +{{- if .Values.bifrost.governance.pricingOverrides }} +{{- $_ := set $governance "pricing_overrides" .Values.bifrost.governance.pricingOverrides }} +{{- end }} {{- if .Values.bifrost.governance.authConfig }} {{- $authConfig := dict }} {{- if and .Values.bifrost.governance.authConfig.existingSecret .Values.bifrost.governance.authConfig.usernameKey }} @@ -379,7 +385,7 @@ false {{- $_ := set $governance "auth_config" $authConfig }} {{- end }} {{- end }} -{{- if or $governance.budgets $governance.rate_limits $governance.customers $governance.teams $governance.virtual_keys $governance.routing_rules $governance.model_configs $governance.providers $governance.auth_config }} +{{- if or $governance.budgets $governance.rate_limits $governance.customers $governance.teams $governance.virtual_keys $governance.routing_rules $governance.model_configs $governance.providers $governance.pricing_overrides $governance.auth_config }} {{- $_ := set $config "governance" $governance }} {{- end }} {{- end }} @@ -670,6 +676,10 @@ false {{- if and (eq $client.connectionType "websocket") $client.websocketConfig }} {{- $_ := set $cc "connection_string" $client.websocketConfig.url }} {{- end }} +{{- /* Map connectionString for SSE connections */ -}} +{{- if and (eq $client.connectionType "sse") $client.connectionString }} +{{- $_ := set $cc "connection_string" $client.connectionString }} +{{- end }} {{- /* Map stdioConfig -> stdio_config */ -}} {{- if $client.stdioConfig }} {{- $stdio := dict "command" $client.stdioConfig.command }} @@ -712,6 +722,12 @@ false {{- if $client.toolPricing }} {{- $_ := set $cc "tool_pricing" $client.toolPricing }} {{- end }} +{{- if $client.allowedExtraHeaders }} +{{- $_ := set $cc "allowed_extra_headers" $client.allowedExtraHeaders }} +{{- end }} +{{- if hasKey $client "allowOnAllVirtualKeys" }} +{{- $_ := set $cc "allow_on_all_virtual_keys" $client.allowOnAllVirtualKeys }} +{{- end }} {{- $clientConfigs = append $clientConfigs $cc }} {{- end }} {{- $mcpConfig := dict "client_configs" $clientConfigs }} @@ -726,6 +742,9 @@ false {{- if .Values.bifrost.mcp.toolManagerConfig.codeModeBindingLevel }} {{- $_ := set $tmConfig "code_mode_binding_level" .Values.bifrost.mcp.toolManagerConfig.codeModeBindingLevel }} {{- end }} +{{- if hasKey .Values.bifrost.mcp.toolManagerConfig "disableAutoToolInject" }} +{{- $_ := set $tmConfig "disable_auto_tool_inject" .Values.bifrost.mcp.toolManagerConfig.disableAutoToolInject }} +{{- end }} {{- if $tmConfig }} {{- $_ := set $mcpConfig "tool_manager_config" $tmConfig }} {{- end }} @@ -901,6 +920,37 @@ false {{- $_ := set $config "audit_logs" $auditLogs }} {{- end }} {{- end }} +{{- /* WebSocket Config */ -}} +{{- if .Values.bifrost.websocket }} +{{- $ws := dict }} +{{- if .Values.bifrost.websocket.maxConnectionsPerUser }} +{{- $_ := set $ws "max_connections_per_user" .Values.bifrost.websocket.maxConnectionsPerUser }} +{{- end }} +{{- if .Values.bifrost.websocket.transcriptBufferSize }} +{{- $_ := set $ws "transcript_buffer_size" .Values.bifrost.websocket.transcriptBufferSize }} +{{- end }} +{{- if .Values.bifrost.websocket.pool }} +{{- $pool := dict }} +{{- if .Values.bifrost.websocket.pool.maxIdlePerKey }} +{{- $_ := set $pool "max_idle_per_key" .Values.bifrost.websocket.pool.maxIdlePerKey }} +{{- end }} +{{- if .Values.bifrost.websocket.pool.maxTotalConnections }} +{{- $_ := set $pool "max_total_connections" .Values.bifrost.websocket.pool.maxTotalConnections }} +{{- end }} +{{- if .Values.bifrost.websocket.pool.idleTimeoutSeconds }} +{{- $_ := set $pool "idle_timeout_seconds" .Values.bifrost.websocket.pool.idleTimeoutSeconds }} +{{- end }} +{{- if .Values.bifrost.websocket.pool.maxConnectionLifetimeSeconds }} +{{- $_ := set $pool "max_connection_lifetime_seconds" .Values.bifrost.websocket.pool.maxConnectionLifetimeSeconds }} +{{- end }} +{{- if $pool }} +{{- $_ := set $ws "pool" $pool }} +{{- end }} +{{- end }} +{{- if $ws }} +{{- $_ := set $config "websocket" $ws }} +{{- end }} +{{- end }} {{- $config | toJson }} {{- end }} diff --git a/helm-charts/bifrost/values.schema.json b/helm-charts/bifrost/values.schema.json index 0f7274135f..495e9c8e79 100644 --- a/helm-charts/bifrost/values.schema.json +++ b/helm-charts/bifrost/values.schema.json @@ -383,6 +383,10 @@ "hideDeletedVirtualKeysInFilters": { "type": "boolean", "description": "When true, deleted virtual keys are omitted from logs and MCP logs filter data" + }, + "mcpDisableAutoToolInject": { + "type": "boolean", + "description": "When true, MCP tools are not automatically injected into requests. Tools are only included when explicitly specified via request context filters or headers." } }, "additionalProperties": false @@ -456,6 +460,11 @@ "type": "string", "enum": ["server", "tool"], "description": "How tools are exposed in VFS for code execution" + }, + "disableAutoToolInject": { + "type": "boolean", + "description": "When true, MCP tools are not automatically injected into requests. Tools are only included when explicitly specified.", + "default": false } } }, @@ -917,6 +926,11 @@ "last_reset": { "type": "string", "format": "date-time" + }, + "calendar_aligned": { + "type": "boolean", + "description": "Snap resets to calendar boundaries (day/week/month/year start)", + "default": false } }, "required": [ @@ -1220,6 +1234,40 @@ }, "required": ["name"] } + }, + "pricingOverrides": { + "type": "array", + "description": "Scoped pricing overrides applied at runtime by the model catalog", + "items": { + "type": "object", + "properties": { + "id": { "type": "string", "description": "Unique pricing override ID" }, + "name": { "type": "string", "description": "Human-readable name for this override" }, + "scope_kind": { + "type": "string", + "enum": ["global", "provider", "provider_key", "virtual_key", "virtual_key_provider", "virtual_key_provider_key"], + "description": "Scope level for this override" + }, + "virtual_key_id": { "type": "string", "description": "Virtual key ID (required for virtual_key* scopes)" }, + "provider_id": { "type": "string", "description": "Provider ID (required for provider* scopes)" }, + "provider_key_id": { "type": "string", "description": "Provider key ID (required for provider_key and virtual_key_provider_key scopes)" }, + "match_type": { + "type": "string", + "enum": ["exact", "wildcard"], + "description": "How the pattern is matched against model names" + }, + "pattern": { "type": "string", "description": "Model name pattern to match" }, + "request_types": { + "type": "array", + "minItems": 1, + "items": { "type": "string" }, + "description": "Request types this override applies to" + }, + "pricing_patch": { "type": "string", "description": "JSON-encoded pricing fields to override" }, + "config_hash": { "type": "string", "description": "Internal hash for change detection (auto-managed)" } + }, + "required": ["id", "name", "scope_kind", "match_type", "pattern", "request_types"] + } } }, "additionalProperties": false @@ -1606,6 +1654,54 @@ "type": "string" } } + }, + "websocket": { + "type": "object", + "description": "Optional tuning for the WebSocket gateway (Responses API WebSocket Mode, Realtime API)", + "properties": { + "maxConnectionsPerUser": { + "type": "integer", + "minimum": 1, + "description": "Maximum concurrent WebSocket connections per user", + "default": 100 + }, + "transcriptBufferSize": { + "type": "integer", + "minimum": 1, + "description": "Number of transcript entries to buffer for Realtime API mid-session fallback", + "default": 100 + }, + "pool": { + "type": "object", + "description": "Upstream WebSocket connection pool configuration", + "properties": { + "maxIdlePerKey": { + "type": "integer", + "minimum": 1, + "description": "Maximum idle connections per provider/key combination", + "default": 50 + }, + "maxTotalConnections": { + "type": "integer", + "minimum": 1, + "description": "Maximum total idle connections across all providers", + "default": 1000 + }, + "idleTimeoutSeconds": { + "type": "integer", + "minimum": 1, + "description": "Seconds before an idle connection is evicted", + "default": 600 + }, + "maxConnectionLifetimeSeconds": { + "type": "integer", + "minimum": 1, + "description": "Maximum lifetime of a connection in seconds", + "default": 7200 + } + } + } + } } } }, @@ -2562,6 +2658,11 @@ "minimum": 1, "maximum": 10000, "description": "Maximum number of TCP connections per provider host. For HTTP/2 (e.g. Bedrock), each connection supports ~100 concurrent streams. Default: 5000." + }, + "beta_header_overrides": { + "type": "object", + "additionalProperties": { "type": "boolean" }, + "description": "Override default Anthropic beta header support per provider. Keys are header prefixes, values are true (supported) or false (unsupported)." } } }, @@ -2725,6 +2826,16 @@ "type": "number", "minimum": 0 } + }, + "allowedExtraHeaders": { + "type": "array", + "items": { "type": "string" }, + "description": "Allowlist of request-level headers that callers may forward to this MCP server. Use ['*'] to allow all headers." + }, + "allowOnAllVirtualKeys": { + "type": "boolean", + "description": "When true, this MCP server is accessible to all virtual keys without requiring explicit per-key assignment.", + "default": false } }, "required": [ @@ -2813,7 +2924,7 @@ }, "allowed_models": { "type": "array", - "description": "Allowed models for this provider config (empty means all models allowed)", + "description": "Allowed models for this provider config. Use [\"*\"] to allow all models; empty array denies all (deny-by-default).", "items": { "type": "string" } @@ -2853,7 +2964,7 @@ "items": { "type": "string" }, - "description": "Supported models for this key" + "description": "Models this key can access. Use [\"*\"] to allow all models; empty array denies all (deny-by-default)." }, "weight": { "type": "number", @@ -2880,7 +2991,6 @@ "description": "Azure API version" } }, - "required": ["endpoint"], "additionalProperties": false }, "vertex_key_config": { @@ -2910,7 +3020,6 @@ "description": "Model to deployment mappings" } }, - "required": ["project_id", "region"], "additionalProperties": false }, "bedrock_key_config": { @@ -2990,10 +3099,6 @@ "description": "Exact model name served on this VLLM instance" } }, - "required": [ - "url", - "model_name" - ], "additionalProperties": false } }, @@ -3048,8 +3153,7 @@ ] } } - ], - "required": ["key_id", "name", "value"] + ] } } }, diff --git a/plugins/governance/allow_on_all_virtual_keys_test.go b/plugins/governance/allow_on_all_virtual_keys_test.go new file mode 100644 index 0000000000..6fab445c56 --- /dev/null +++ b/plugins/governance/allow_on_all_virtual_keys_test.go @@ -0,0 +1,168 @@ +package governance + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" +) + +// mockInMemoryStore is a test double for InMemoryStore. +type mockInMemoryStore struct { + allowAllClients map[string]string // clientID β†’ clientName + configuredProviders map[schemas.ModelProvider]configstore.ProviderConfig +} + +func (m *mockInMemoryStore) GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig { + return m.configuredProviders +} + +func (m *mockInMemoryStore) GetMCPClientsAllowingAllVirtualKeys() map[string]string { + return m.allowAllClients +} + +// newPluginWithInMemoryStore builds a minimal GovernancePlugin wired with a mock InMemoryStore. +func newPluginWithInMemoryStore(store InMemoryStore) *GovernancePlugin { + return &GovernancePlugin{inMemoryStore: store} +} + +// buildVKWithMCPConfigs returns a VK that has explicit MCPConfigs for the given client. +func buildVKWithMCPConfigs(clientID, clientName string, tools []string) *configstoreTables.TableVirtualKey { + return &configstoreTables.TableVirtualKey{ + ID: "vk-1", + Name: "test-vk", + MCPConfigs: []configstoreTables.TableVirtualKeyMCPConfig{ + { + MCPClient: configstoreTables.TableMCPClient{ + ClientID: clientID, + Name: clientName, + }, + ToolsToExecute: tools, + }, + }, + } +} + +// buildVKNoMCPConfigs returns a VK with no MCPConfigs at all. +func buildVKNoMCPConfigs() *configstoreTables.TableVirtualKey { + return &configstoreTables.TableVirtualKey{ + ID: "vk-2", + Name: "test-vk-empty", + } +} + +// ============================================================================ +// isMCPToolAllowedByVKWith β€” AllowOnAllVirtualKeys scenarios +// ============================================================================ + +// VK with no MCPConfigs + AllowOnAllVirtualKeys client β†’ tools allowed +func TestIsMCPToolAllowedByVKWith_NoVKConfig_AllowAllEnabled(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + vk := buildVKNoMCPConfigs() + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-search", map[string]string{"client-1": "youtube"}), + "specific tool should be allowed when AllowOnAllVirtualKeys is set and VK has no explicit config") + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-*", map[string]string{"client-1": "youtube"}), + "wildcard pattern should be allowed when AllowOnAllVirtualKeys is set and VK has no explicit config") +} + +// VK with explicit empty tools config for an AllowOnAllVirtualKeys client β†’ tools blocked +func TestIsMCPToolAllowedByVKWith_ExplicitEmptyConfig_Blocks(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + // Explicit VK config with empty tools list (deny-all for this client) + vk := buildVKWithMCPConfigs("client-1", "youtube", []string{}) + + assert.False(t, p.isMCPToolAllowedByVKWith(vk, "youtube-search", map[string]string{"client-1": "youtube"}), + "explicit empty tools list should block access even when AllowOnAllVirtualKeys is set") + + assert.False(t, p.isMCPToolAllowedByVKWith(vk, "youtube-*", map[string]string{"client-1": "youtube"}), + "wildcard should be blocked when explicit config has empty tools list") +} + +// VK with explicit ["tool1"] config for an AllowOnAllVirtualKeys client β†’ only tool1 allowed +func TestIsMCPToolAllowedByVKWith_ExplicitPartialConfig_OnlyListedToolsAllowed(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + vk := buildVKWithMCPConfigs("client-1", "youtube", []string{"search"}) + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-search", map[string]string{"client-1": "youtube"}), + "explicitly listed tool should be allowed") + + assert.False(t, p.isMCPToolAllowedByVKWith(vk, "youtube-upload", map[string]string{"client-1": "youtube"}), + "non-listed tool should be blocked even when AllowOnAllVirtualKeys is set") +} + +// inMemoryStore is nil β†’ AllowOnAllVirtualKeys clients are treated as not configured (all blocked) +func TestIsMCPToolAllowedByVKWith_NilInMemoryStore_AllBlocked(t *testing.T) { + p := &GovernancePlugin{inMemoryStore: nil} + vk := buildVKNoMCPConfigs() + + allowed := p.isMCPToolAllowedByVKWith(vk, "youtube-search", nil) + assert.False(t, allowed, + "nil inMemoryStore means no AllowOnAllVirtualKeys clients; tool should be blocked") +} + +// Wildcard pattern (clientName-*) with AllowOnAllVirtualKeys client and no VK config β†’ allowed +func TestIsMCPToolAllowedByVKWith_WildcardPattern_AllowAll_NoVKConfig(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + vk := buildVKNoMCPConfigs() + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-*", map[string]string{"client-1": "youtube"}), + "clientName-* wildcard should match AllowOnAllVirtualKeys fallback") +} + +// Explicit unrestricted config (["*"]) for AllowOnAllVirtualKeys client β†’ all tools allowed +func TestIsMCPToolAllowedByVKWith_ExplicitUnrestrictedConfig_AllowsAll(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + vk := buildVKWithMCPConfigs("client-1", "youtube", []string{"*"}) + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-search", map[string]string{"client-1": "youtube"}), + "unrestricted explicit config should allow all tools") + + assert.True(t, p.isMCPToolAllowedByVKWith(vk, "youtube-*", map[string]string{"client-1": "youtube"}), + "wildcard should match when explicit config is unrestricted") +} + +// Tool belonging to a different client is not allowed via AllowOnAllVirtualKeys of another client +func TestIsMCPToolAllowedByVKWith_DifferentClient_Blocked(t *testing.T) { + p := newPluginWithInMemoryStore(&mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + }) + vk := buildVKNoMCPConfigs() + + assert.False(t, p.isMCPToolAllowedByVKWith(vk, "github-list_repos", map[string]string{"client-1": "youtube"}), + "tool from a different client should not be allowed via another client's AllowOnAllVirtualKeys") +} + +// isMCPToolAllowedByVK delegates to inMemoryStore correctly +func TestIsMCPToolAllowedByVK_UsesInMemoryStore(t *testing.T) { + store := &mockInMemoryStore{ + allowAllClients: map[string]string{"client-1": "youtube"}, + } + p := newPluginWithInMemoryStore(store) + vk := buildVKNoMCPConfigs() + + assert.True(t, p.isMCPToolAllowedByVK(vk, "youtube-search"), + "isMCPToolAllowedByVK should use inMemoryStore to resolve AllowOnAllVirtualKeys") +} + +// isMCPToolAllowedByVK with nil inMemoryStore β†’ blocked +func TestIsMCPToolAllowedByVK_NilStore_Blocked(t *testing.T) { + p := &GovernancePlugin{inMemoryStore: nil} + vk := buildVKNoMCPConfigs() + + assert.False(t, p.isMCPToolAllowedByVK(vk, "youtube-search"), + "nil inMemoryStore should result in blocked access") +} diff --git a/plugins/governance/changelog.md b/plugins/governance/changelog.md index e69de29bb2..657710e664 100644 --- a/plugins/governance/changelog.md +++ b/plugins/governance/changelog.md @@ -0,0 +1,5 @@ +- feat: add realtime WebSocket, WebRTC, and client secret handlers +- feat: add access profiles for fine-grained permission control +- feat: add support for tracking userId, teamId, customerId, and businessUnitId +- fix: SQLite migration connections and error handling + vk not found message +- fix: preserve routing rule targets for genai and bedrock paths diff --git a/plugins/governance/go.mod b/plugins/governance/go.mod index c4cd4bc1c8..57c169d378 100644 --- a/plugins/governance/go.mod +++ b/plugins/governance/go.mod @@ -8,8 +8,8 @@ require ( github.com/bytedance/sonic v1.15.0 github.com/google/cel-go v0.26.1 github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 github.com/stretchr/testify v1.11.1 github.com/valyala/fasthttp v1.68.0 ) diff --git a/plugins/governance/go.sum b/plugins/governance/go.sum index 9c248728d7..87092e1470 100644 --- a/plugins/governance/go.sum +++ b/plugins/governance/go.sum @@ -199,10 +199,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/plugins/governance/http_transport_prehook_test.go b/plugins/governance/http_transport_prehook_test.go index c79ca20b15..295b19f0b9 100644 --- a/plugins/governance/http_transport_prehook_test.go +++ b/plugins/governance/http_transport_prehook_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -23,14 +24,14 @@ func TestHTTPTransportPreHook_VirtualKeyReplicateRefinesNestedModel(t *testing.T Data: []schemas.Model{ {ID: "replicate/openai/gpt-5-nano"}, }, - }, nil, nil) + }, nil) virtualKey := buildVirtualKeyWithProviders( "vk1", "sk-bf-test", "replicate-only", []configstoreTables.TableVirtualKeyProviderConfig{ - buildProviderConfig("replicate", nil), + buildProviderConfig("replicate", []string{"*"}), }, ) store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ @@ -63,3 +64,306 @@ func TestHTTPTransportPreHook_VirtualKeyReplicateRefinesNestedModel(t *testing.T require.NoError(t, json.Unmarshal(req.Body, &payload)) require.Equal(t, "replicate/openai/gpt-5-nano", payload.Model) } + +// TestHTTPTransportPreHook_GenAIRoutingRulePreservesTarget verifies that when a routing rule +// matches on the /genai path, governance load balancing does not override the routing-rule target +// with a provider from the VK pool (regression test for issue #2516). +func TestHTTPTransportPreHook_GenAIRoutingRulePreservesTarget(t *testing.T) { + logger := NewMockLogger() + + routingRule := configstoreTables.TableRoutingRule{ + ID: "rule-genai-1", + Name: "genai-repro-rule", + Enabled: true, + CelExpression: `model == "probe-genai-model" && provider == ""`, + Targets: []configstoreTables.TableRoutingTarget{ + { + RuleID: "rule-genai-1", + Provider: bifrost.Ptr("repro-openai-a"), + Model: bifrost.Ptr("error-test"), + Weight: 1.0, + }, + }, + Scope: "global", + Priority: 1, + } + + // VK with repro-openai-b at weight=1 β€” this is what governance LB would wrongly select without the fix + virtualKey := buildVirtualKeyWithProviders( + "vk-genai", + "sk-bf-genai-test", + "genai-repro-vk", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("repro-openai-b", []string{"*"}), + }, + ) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*virtualKey}, + RoutingRules: []configstoreTables.TableRoutingRule{routingRule}, + }, nil) + require.NoError(t, err) + + plugin, err := InitFromStore(context.Background(), &Config{IsVkMandatory: boolPtr(false)}, logger, store, nil, nil, nil, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, plugin.Cleanup()) + }() + + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = "POST" + req.Path = "/genai/v1beta/models/probe-genai-model:generateContent" + req.PathParams["model"] = "probe-genai-model:generateContent" + req.Headers["Authorization"] = "Bearer sk-bf-genai-test" + req.Headers["Content-Type"] = "application/json" + req.Body = []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + resp, err := plugin.HTTPTransportPreHook(bfCtx, req) + require.NoError(t, err) + require.Nil(t, resp) + + // Routing rule matched and set context model to "repro-openai-a/error-test:generateContent". + // Governance LB must NOT override this with "repro-openai-b/probe-genai-model:generateContent". + ctxModel, ok := bfCtx.Value("model").(string) + require.True(t, ok, "context model should be set") + require.Equal(t, "repro-openai-a/error-test:generateContent", ctxModel) +} + +// TestHTTPTransportPreHook_GenAIRoutingRulePreservesTarget_WithStore is a production-like variant +// of TestHTTPTransportPreHook_GenAIRoutingRulePreservesTarget that passes a non-nil inMemoryStore +// containing the routing-rule provider, confirming the fix holds when p.inMemoryStore != nil +// and the provider IS present in GetConfiguredProviders (the normal production code path). +func TestHTTPTransportPreHook_GenAIRoutingRulePreservesTarget_WithStore(t *testing.T) { + logger := NewMockLogger() + + routingRule := configstoreTables.TableRoutingRule{ + ID: "rule-genai-ws-1", + Name: "genai-repro-rule-with-store", + Enabled: true, + CelExpression: `model == "probe-genai-model" && provider == ""`, + Targets: []configstoreTables.TableRoutingTarget{ + { + RuleID: "rule-genai-ws-1", + Provider: bifrost.Ptr("repro-openai-a"), + Model: bifrost.Ptr("error-test"), + Weight: 1.0, + }, + }, + Scope: "global", + Priority: 1, + } + + virtualKey := buildVirtualKeyWithProviders( + "vk-genai-ws", + "sk-bf-genai-ws-test", + "genai-repro-vk-with-store", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("repro-openai-b", []string{"*"}), + }, + ) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*virtualKey}, + RoutingRules: []configstoreTables.TableRoutingRule{routingRule}, + }, nil) + require.NoError(t, err) + + // Register the fake provider so ParseModelString can split "repro-openai-a/model" + // the same way it would for a real provider in production. + schemas.RegisterKnownProvider("repro-openai-a") + t.Cleanup(func() { schemas.UnregisterKnownProvider("repro-openai-a") }) + + // Use a non-nil inMemoryStore that recognises the routing-rule provider, + // mirroring production where configured providers are always registered in the store. + inMemStore := &mockInMemoryStore{ + configuredProviders: map[schemas.ModelProvider]configstore.ProviderConfig{ + "repro-openai-a": {}, + }, + } + + plugin, err := InitFromStore(context.Background(), &Config{IsVkMandatory: boolPtr(false)}, logger, store, nil, nil, nil, inMemStore) + require.NoError(t, err) + defer func() { + require.NoError(t, plugin.Cleanup()) + }() + + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = "POST" + req.Path = "/genai/v1beta/models/probe-genai-model:generateContent" + req.PathParams["model"] = "probe-genai-model:generateContent" + req.Headers["Authorization"] = "Bearer sk-bf-genai-ws-test" + req.Headers["Content-Type"] = "application/json" + req.Body = []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + resp, err := plugin.HTTPTransportPreHook(bfCtx, req) + require.NoError(t, err) + require.Nil(t, resp) + + ctxModel, ok := bfCtx.Value("model").(string) + require.True(t, ok, "context model should be set") + require.Equal(t, "repro-openai-a/error-test:generateContent", ctxModel) +} + +// TestHTTPTransportPreHook_GenAINoRoutingRuleStillLoadBalances verifies that when no routing rule +// matches on the /genai path, governance load balancing still selects a provider from the VK pool. +func TestHTTPTransportPreHook_GenAINoRoutingRuleStillLoadBalances(t *testing.T) { + logger := NewMockLogger() + + // VK with repro-openai-b at weight=1 β€” LB should select this + virtualKey := buildVirtualKeyWithProviders( + "vk-genai-lb", + "sk-bf-genai-lb-test", + "genai-lb-vk", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("repro-openai-b", []string{"*"}), + }, + ) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*virtualKey}, + // No routing rules β€” governance LB should run normally + }, nil) + require.NoError(t, err) + + plugin, err := InitFromStore(context.Background(), &Config{IsVkMandatory: boolPtr(false)}, logger, store, nil, nil, nil, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, plugin.Cleanup()) + }() + + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = "POST" + req.Path = "/genai/v1beta/models/probe-genai-model:generateContent" + req.PathParams["model"] = "probe-genai-model:generateContent" + req.Headers["Authorization"] = "Bearer sk-bf-genai-lb-test" + req.Headers["Content-Type"] = "application/json" + req.Body = []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + resp, err := plugin.HTTPTransportPreHook(bfCtx, req) + require.NoError(t, err) + require.Nil(t, resp) + + // No routing rule: governance LB must still run and select repro-openai-b from the VK pool + ctxModel, ok := bfCtx.Value("model").(string) + require.True(t, ok, "context model should be set by governance LB") + require.Equal(t, "repro-openai-b/probe-genai-model:generateContent", ctxModel) +} + +// TestHTTPTransportPreHook_BedrockRoutingRulePreservesTarget verifies that when a routing rule +// matches on the /bedrock path, governance load balancing does not override the routing-rule target +// (regression test mirroring the GenAI fix for the Bedrock integration). +func TestHTTPTransportPreHook_BedrockRoutingRulePreservesTarget(t *testing.T) { + logger := NewMockLogger() + + routingRule := configstoreTables.TableRoutingRule{ + ID: "rule-bedrock-1", + Name: "bedrock-repro-rule", + Enabled: true, + CelExpression: `model == "probe-bedrock-model" && provider == ""`, + Targets: []configstoreTables.TableRoutingTarget{ + { + RuleID: "rule-bedrock-1", + Provider: bifrost.Ptr("repro-openai-a"), + Model: bifrost.Ptr("error-test"), + Weight: 1.0, + }, + }, + Scope: "global", + Priority: 1, + } + + // VK with repro-openai-b at weight=1 β€” this is what governance LB would wrongly select without the fix + virtualKey := buildVirtualKeyWithProviders( + "vk-bedrock", + "sk-bf-bedrock-test", + "bedrock-repro-vk", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("repro-openai-b", []string{"*"}), + }, + ) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*virtualKey}, + RoutingRules: []configstoreTables.TableRoutingRule{routingRule}, + }, nil) + require.NoError(t, err) + + plugin, err := InitFromStore(context.Background(), &Config{IsVkMandatory: boolPtr(false)}, logger, store, nil, nil, nil, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, plugin.Cleanup()) + }() + + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = "POST" + req.Path = "/bedrock/model/probe-bedrock-model/converse" + req.PathParams["modelId"] = "probe-bedrock-model" + req.Headers["Authorization"] = "Bearer sk-bf-bedrock-test" + req.Headers["Content-Type"] = "application/json" + req.Body = []byte(`{"messages":[{"role":"user","content":[{"text":"hi"}]}]}`) + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + resp, err := plugin.HTTPTransportPreHook(bfCtx, req) + require.NoError(t, err) + require.Nil(t, resp) + + // Routing rule matched and set context modelId to "repro-openai-a/error-test". + // Governance LB must NOT override this with "repro-openai-b/probe-bedrock-model". + ctxModelID, ok := bfCtx.Value("modelId").(string) + require.True(t, ok, "context modelId should be set") + require.Equal(t, "repro-openai-a/error-test", ctxModelID) +} + +// TestHTTPTransportPreHook_BedrockNoRoutingRuleStillLoadBalances verifies that when no routing rule +// matches on the /bedrock path, governance load balancing still selects a provider from the VK pool. +func TestHTTPTransportPreHook_BedrockNoRoutingRuleStillLoadBalances(t *testing.T) { + logger := NewMockLogger() + + // VK with repro-openai-b at weight=1 β€” LB should select this + virtualKey := buildVirtualKeyWithProviders( + "vk-bedrock-lb", + "sk-bf-bedrock-lb-test", + "bedrock-lb-vk", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("repro-openai-b", []string{"*"}), + }, + ) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*virtualKey}, + // No routing rules β€” governance LB should run normally + }, nil) + require.NoError(t, err) + + plugin, err := InitFromStore(context.Background(), &Config{IsVkMandatory: boolPtr(false)}, logger, store, nil, nil, nil, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, plugin.Cleanup()) + }() + + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + req.Method = "POST" + req.Path = "/bedrock/model/probe-bedrock-model/converse" + req.PathParams["modelId"] = "probe-bedrock-model" + req.Headers["Authorization"] = "Bearer sk-bf-bedrock-lb-test" + req.Headers["Content-Type"] = "application/json" + req.Body = []byte(`{"messages":[{"role":"user","content":[{"text":"hi"}]}]}`) + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + resp, err := plugin.HTTPTransportPreHook(bfCtx, req) + require.NoError(t, err) + require.Nil(t, resp) + + // No routing rule: governance LB must still run and select repro-openai-b from the VK pool + ctxModelID, ok := bfCtx.Value("modelId").(string) + require.True(t, ok, "context modelId should be set by governance LB") + require.Equal(t, "repro-openai-b/probe-bedrock-model", ctxModelID) +} diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 0098de2fe5..de02673d6b 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -7,7 +7,6 @@ import ( "fmt" "math/rand/v2" "net/url" - "slices" "sort" "strings" "sync" @@ -29,26 +28,28 @@ import ( const PluginName = "governance" const ( - governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected" - governanceIsCacheReadContextKey schemas.BifrostContextKey = "bf-governance-is-cache-read" - governanceIsBatchContextKey schemas.BifrostContextKey = "bf-governance-is-batch" + governanceRejectedContextKey schemas.BifrostContextKey = "bf-governance-rejected" VirtualKeyPrefix = "sk-bf-" ) // Config is the configuration for the governance plugin type Config struct { - IsVkMandatory *bool `json:"is_vk_mandatory"` - RequiredHeaders *[]string `json:"required_headers"` // Pointer to live config slice; changes are reflected immediately without restart - IsEnterprise bool `json:"is_enterprise"` + IsVkMandatory *bool `json:"is_vk_mandatory"` + RequiredHeaders *[]string `json:"required_headers"` // Pointer to live config slice; changes are reflected immediately without restart + IsEnterprise bool `json:"is_enterprise"` + DisableAutoToolInject *bool `json:"disable_auto_tool_inject"` + RoutingChainMaxDepth *int `json:"routing_chain_max_depth"` // Pointer to live config value; changes are reflected immediately without restart } type InMemoryStore interface { GetConfiguredProviders() map[schemas.ModelProvider]configstore.ProviderConfig + GetMCPClientsAllowingAllVirtualKeys() map[string]string // clientID β†’ clientName } type BaseGovernancePlugin interface { GetName() string + EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) @@ -83,9 +84,10 @@ type GovernancePlugin struct { cfgMutex sync.RWMutex - isVkMandatory *bool - requiredHeaders *[]string // pointer to live config slice; lowercased at check time - isEnterprise bool + isVkMandatory *bool + requiredHeaders *[]string // pointer to live config slice; lowercased at check time + isEnterprise bool + disableAutoToolInject *bool } // Init initializes and returns a governance plugin instance. @@ -150,9 +152,17 @@ func Init( // Handle nil config - use safe defaults var isVkMandatory *bool var requiredHeaders *[]string + var disableAutoToolInject *bool + var routingChainMaxDepth *int if config != nil { isVkMandatory = config.IsVkMandatory requiredHeaders = config.RequiredHeaders + disableAutoToolInject = config.DisableAutoToolInject + routingChainMaxDepth = config.RoutingChainMaxDepth + } + if routingChainMaxDepth == nil { + defaultDepth := DefaultRoutingChainMaxDepth + routingChainMaxDepth = &defaultDepth } governanceStore, err := NewLocalGovernanceStore(ctx, logger, configStore, governanceConfig, modelCatalog) @@ -196,28 +206,29 @@ func Init( } // 5. Routing engine (dynamically routing requests based on routing rules) - engine, err := NewRoutingEngine(governanceStore, logger) + engine, err := NewRoutingEngine(governanceStore, logger, routingChainMaxDepth) if err != nil { return nil, fmt.Errorf("failed to initialize routing engine: %w", err) } ctx, cancelFunc := context.WithCancel(ctx) plugin := &GovernancePlugin{ - ctx: ctx, - cancelFunc: cancelFunc, - store: governanceStore, - resolver: resolver, - tracker: tracker, - engine: engine, - configStore: configStore, - modelCatalog: modelCatalog, - mcpCatalog: mcpCatalog, - logger: logger, - isVkMandatory: isVkMandatory, - cfgMutex: sync.RWMutex{}, - requiredHeaders: requiredHeaders, - isEnterprise: config != nil && config.IsEnterprise, - inMemoryStore: inMemoryStore, + ctx: ctx, + cancelFunc: cancelFunc, + store: governanceStore, + resolver: resolver, + tracker: tracker, + engine: engine, + configStore: configStore, + modelCatalog: modelCatalog, + mcpCatalog: mcpCatalog, + logger: logger, + isVkMandatory: isVkMandatory, + cfgMutex: sync.RWMutex{}, + requiredHeaders: requiredHeaders, + isEnterprise: config != nil && config.IsEnterprise, + disableAutoToolInject: disableAutoToolInject, + inMemoryStore: inMemoryStore, } return plugin, nil } @@ -259,13 +270,21 @@ func InitFromStore( // Handle nil config - use safe defaults var isVkMandatory *bool var requiredHeaders *[]string + var disableAutoToolInject *bool + var routingChainMaxDepth *int if config != nil { isVkMandatory = config.IsVkMandatory requiredHeaders = config.RequiredHeaders + disableAutoToolInject = config.DisableAutoToolInject + routingChainMaxDepth = config.RoutingChainMaxDepth + } + if routingChainMaxDepth == nil { + defaultDepth := DefaultRoutingChainMaxDepth + routingChainMaxDepth = &defaultDepth } resolver := NewBudgetResolver(governanceStore, modelCatalog, logger) tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) - engine, err := NewRoutingEngine(governanceStore, logger) + engine, err := NewRoutingEngine(governanceStore, logger, routingChainMaxDepth) if err != nil { return nil, fmt.Errorf("failed to initialize routing engine: %w", err) } @@ -288,21 +307,22 @@ func InitFromStore( } ctx, cancelFunc := context.WithCancel(ctx) plugin := &GovernancePlugin{ - ctx: ctx, - cancelFunc: cancelFunc, - store: governanceStore, - resolver: resolver, - tracker: tracker, - engine: engine, - configStore: configStore, - modelCatalog: modelCatalog, - mcpCatalog: mcpCatalog, - logger: logger, - inMemoryStore: inMemoryStore, - isVkMandatory: isVkMandatory, - cfgMutex: sync.RWMutex{}, - requiredHeaders: requiredHeaders, - isEnterprise: config != nil && config.IsEnterprise, + ctx: ctx, + cancelFunc: cancelFunc, + store: governanceStore, + resolver: resolver, + tracker: tracker, + engine: engine, + configStore: configStore, + modelCatalog: modelCatalog, + mcpCatalog: mcpCatalog, + logger: logger, + inMemoryStore: inMemoryStore, + isVkMandatory: isVkMandatory, + cfgMutex: sync.RWMutex{}, + requiredHeaders: requiredHeaders, + isEnterprise: config != nil && config.IsEnterprise, + disableAutoToolInject: disableAutoToolInject, } return plugin, nil } @@ -316,7 +336,7 @@ func (p *GovernancePlugin) GetName() string { func (p *GovernancePlugin) UpdateEnforceAuthOnInference(enforceAuthOnInference bool) { p.cfgMutex.Lock() defer p.cfgMutex.Unlock() - p.isVkMandatory = bifrost.Ptr(enforceAuthOnInference) + p.isVkMandatory = new(enforceAuthOnInference) } // HTTPTransportPreHook intercepts requests before they are processed (governance decision point) @@ -347,7 +367,19 @@ func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req var needsMarshal bool contentType := req.CaseInsensitiveHeaderLookup("Content-Type") - isMultipart := strings.HasPrefix(strings.ToLower(contentType), "multipart/form-data") + lowerCT := strings.ToLower(contentType) + // Strip parameters (e.g., "; charset=utf-8") for clean media type comparison + mediaType := lowerCT + if idx := strings.IndexByte(mediaType, ';'); idx >= 0 { + mediaType = strings.TrimSpace(mediaType[:idx]) + } + isMultipart := strings.HasPrefix(mediaType, "multipart/form-data") + isJSON := mediaType == "" || mediaType == "application/json" || strings.HasSuffix(mediaType, "+json") + + if !isMultipart && !isJSON { + // Non-parseable body (e.g., application/sdp for WebRTC signaling) β€” skip governance + return nil, nil + } var err error if isMultipart { @@ -372,6 +404,22 @@ func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req } } + // Attaching team and customer based on the virtual key + if virtualKey != nil { + if virtualKey.TeamID != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *virtualKey.TeamID) + } + if virtualKey.Team != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, virtualKey.Team.Name) + } + if virtualKey.CustomerID != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *virtualKey.CustomerID) + } + if virtualKey.Customer != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, virtualKey.Customer.Name) + } + } + //1. Apply routing rules only if we have rules or matched decision var routingDecision *RoutingDecision if hasRoutingRules { @@ -393,14 +441,27 @@ func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req if err != nil { return nil, err } - //3. Add MCP tools - headers, err := p.addMCPIncludeTools(nil, virtualKey) - if err != nil { - p.logger.Error("failed to add MCP include tools: %v", err) - return nil, nil - } - for header, value := range headers { - req.Headers[header] = value + //3. Add MCP tools only when auto-inject is enabled and header not already set by the caller + p.cfgMutex.RLock() + autoInjectDisabled := p.disableAutoToolInject != nil && *p.disableAutoToolInject + p.cfgMutex.RUnlock() + if !autoInjectDisabled { + // Treat an explicitly-present (even empty) x-bf-mcp-include-tools header as "present" + // so that callers can block auto-injection by sending an empty header value. + headerPresent := false + for k := range req.Headers { + if strings.EqualFold(k, "x-bf-mcp-include-tools") { + headerPresent = true + break + } + } + if !headerPresent { + req.Headers, err = p.addMCPIncludeTools(req.Headers, virtualKey) + if err != nil { + p.logger.Error("failed to add MCP include tools: %v", err) + return nil, nil + } + } } needsMarshal = true } @@ -442,6 +503,22 @@ func (p *GovernancePlugin) governLargePayload(ctx *schemas.BifrostContext, req * virtualKey = vk } + // Attaching team and customer based on the virtual key + if virtualKey != nil { + if virtualKey.TeamID != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, *virtualKey.TeamID) + } + if virtualKey.Team != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, virtualKey.Team.Name) + } + if virtualKey.CustomerID != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, *virtualKey.CustomerID) + } + if virtualKey.Customer != nil { + ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, virtualKey.Customer.Name) + } + } + // Apply routing rules (read-only: decisions still affect downstream evaluation) if hasRoutingRules { var err error @@ -458,14 +535,26 @@ func (p *GovernancePlugin) governLargePayload(ctx *schemas.BifrostContext, req * if err != nil { return nil, err } - // MCP tool headers β€” header-only, no body needed - headers, err := p.addMCPIncludeTools(nil, virtualKey) - if err != nil { - p.logger.Error("failed to add MCP include tools: %v", err) - return nil, nil - } - for header, value := range headers { - req.Headers[header] = value + // MCP tool headers β€” apply the same auto-inject guard as the normal path: + // skip when DisableAutoToolInject is set or the caller already sent the header. + p.cfgMutex.RLock() + autoInjectDisabled := p.disableAutoToolInject != nil && *p.disableAutoToolInject + p.cfgMutex.RUnlock() + if !autoInjectDisabled { + headerPresent := false + for k := range req.Headers { + if strings.EqualFold(k, "x-bf-mcp-include-tools") { + headerPresent = true + break + } + } + if !headerPresent { + req.Headers, err = p.addMCPIncludeTools(req.Headers, virtualKey) + if err != nil { + p.logger.Error("failed to add MCP include tools: %v", err) + return nil, nil + } + } } } @@ -507,19 +596,29 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req if !hasModel { // For genai integration, model is present in URL path instead of the request body if isGeminiPath { - modelValue = req.CaseInsensitivePathParamLookup("model") + // Prefer context value set by a routing rule (format: "provider/model:suffix") + if ctxModel, ok := ctx.Value("model").(string); ok && ctxModel != "" { + modelValue = ctxModel + } else { + modelValue = req.CaseInsensitivePathParamLookup("model") + } } else if isBedrockPath { // For bedrock integration, model is present in URL path as modelId - rawModelID := req.CaseInsensitivePathParamLookup("modelId") - if rawModelID == "" { - return body, nil - } - // URL-decode the modelId (Bedrock model IDs may be URL-encoded, e.g. anthropic%2Fclaude-3-5-sonnet) - decoded, err := url.PathUnescape(rawModelID) - if err != nil { - decoded = rawModelID + // Prefer context value set by a routing rule (format: "provider/model") + if ctxModelID, ok := ctx.Value("modelId").(string); ok && ctxModelID != "" { + modelValue = ctxModelID + } else { + rawModelID := req.CaseInsensitivePathParamLookup("modelId") + if rawModelID == "" { + return body, nil + } + // URL-decode the modelId (Bedrock model IDs may be URL-encoded, e.g. anthropic%2Fclaude-3-5-sonnet) + decoded, err := url.PathUnescape(rawModelID) + if err != nil { + decoded = rawModelID + } + modelValue = decoded } - modelValue = decoded } else { return body, nil } @@ -580,12 +679,8 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(schemas.ModelProvider(config.Provider), modelStr, config.AllowedModels) } else { // Fallback when model catalog is not available: simple string matching - if len(config.AllowedModels) == 0 { - // No restrictions, allow all models - isProviderAllowed = true - } else { - isProviderAllowed = slices.Contains(config.AllowedModels, modelStr) - } + // ["*"] = allow all models; [] = deny all models + isProviderAllowed = config.AllowedModels.IsAllowed(modelStr) } if isProviderAllowed { @@ -617,26 +712,40 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req // No allowed provider configs, continue without modification return body, nil } - // Weighted random selection from allowed providers for the main model - totalWeight := 0.0 + // Separate providers with weight set (participate in routing) from those without (nil weight = excluded from routing) + weightedConfigs := make([]configstoreTables.TableVirtualKeyProviderConfig, 0, len(allowedProviderConfigs)) for _, config := range allowedProviderConfigs { - totalWeight += getWeight(config.Weight) + if config.Weight != nil { + weightedConfigs = append(weightedConfigs, config) + } } - // Generate random number between 0 and totalWeight - randomValue := rand.Float64() * totalWeight - // Select provider based on weighted random selection + var selectedProvider schemas.ModelProvider - currentWeight := 0.0 - for _, config := range allowedProviderConfigs { - currentWeight += getWeight(config.Weight) - if randomValue <= currentWeight { - selectedProvider = schemas.ModelProvider(config.Provider) - break + + if len(weightedConfigs) > 0 { + // Weighted random selection from providers that have weight set + totalWeight := 0.0 + for _, config := range weightedConfigs { + totalWeight += getWeight(config.Weight) } - } - // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) - if selectedProvider == "" && len(allowedProviderConfigs) > 0 { - selectedProvider = schemas.ModelProvider(allowedProviderConfigs[0].Provider) + // Generate random number between 0 and totalWeight + randomValue := rand.Float64() * totalWeight + // Select provider based on weighted random selection + currentWeight := 0.0 + for _, config := range weightedConfigs { + currentWeight += getWeight(config.Weight) + if randomValue <= currentWeight { + selectedProvider = schemas.ModelProvider(config.Provider) + break + } + } + // Fallback: if no provider was selected (shouldn't happen but guard against FP issues) + if selectedProvider == "" { + selectedProvider = schemas.ModelProvider(weightedConfigs[0].Provider) + } + } else { + // No providers have weight set + return body, nil } p.logger.Debug("[Governance] Selected provider: %s", selectedProvider) @@ -667,15 +776,17 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req // Check if fallbacks field is already present _, hasFallbacks := body["fallbacks"] - if !hasFallbacks && len(allowedProviderConfigs) > 1 { - // Sort allowed provider configs by weight (descending) - sort.Slice(allowedProviderConfigs, func(i, j int) bool { - return getWeight(allowedProviderConfigs[i].Weight) > getWeight(allowedProviderConfigs[j].Weight) + // Use the same candidate set that was used for primary selection + fallbackConfigs := weightedConfigs + if !hasFallbacks && len(fallbackConfigs) > 1 { + // Sort fallback configs by weight (descending) + sort.Slice(fallbackConfigs, func(i, j int) bool { + return getWeight(fallbackConfigs[i].Weight) > getWeight(fallbackConfigs[j].Weight) }) // Filter out the selected provider and create fallbacks array - fallbacks := make([]string, 0, len(allowedProviderConfigs)-1) - for _, config := range allowedProviderConfigs { + fallbacks := make([]string, 0, len(fallbackConfigs)-1) + for _, config := range fallbackConfigs { if config.Provider != string(selectedProvider) { var err error refinedModel := modelStr @@ -847,35 +958,55 @@ func (p *GovernancePlugin) applyRoutingRules(ctx *schemas.BifrostContext, req *s // - map[string]string: The updated request headers // - error: Any error that occurred during processing func (p *GovernancePlugin) addMCPIncludeTools(headers map[string]string, virtualKey *configstoreTables.TableVirtualKey) (map[string]string, error) { - if len(virtualKey.MCPConfigs) > 0 { - if headers == nil { - headers = make(map[string]string) - } - executeOnlyTools := make([]string, 0) - for _, vkMcpConfig := range virtualKey.MCPConfigs { - if len(vkMcpConfig.ToolsToExecute) == 0 { - // No tools specified in virtual key config - skip this client entirely - continue - } - // Handle wildcard in virtual key config - allow all tools from this client - if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { - // Virtual key uses wildcard - use client-specific wildcard - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) - continue - } + if headers == nil { + headers = make(map[string]string) + } - for _, tool := range vkMcpConfig.ToolsToExecute { - if tool != "" { - // Add the tool - client config filtering will be handled by mcp.go - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) - } + executeOnlyTools := make([]string, 0) + + // Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName + var allowAllVKsClients map[string]string + if p.inMemoryStore != nil { + allowAllVKsClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() + } + if allowAllVKsClients == nil { + allowAllVKsClients = make(map[string]string) + } + + // Process VK-specific MCP configs first β€” explicit config always overrides AllowOnAllVirtualKeys. + // Track which AllowOnAllVirtualKeys clients have an explicit VK config so we don't double-add them. + handledClients := make(map[string]bool) + for _, vkMcpConfig := range virtualKey.MCPConfigs { + clientID := vkMcpConfig.MCPClient.ClientID + if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll { + // Explicit VK config exists β€” it takes precedence; mark as handled regardless of tool list + handledClients[clientID] = true + } + if vkMcpConfig.ToolsToExecute.IsEmpty() { + // No tools specified in virtual key config - skip this client entirely + continue + } + if vkMcpConfig.ToolsToExecute.IsUnrestricted() { + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) + continue + } + for _, tool := range vkMcpConfig.ToolsToExecute { + if tool != "" { + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) } } + } - // Set even when empty to exclude tools when no tools are present in the virtual key config - headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",") + // For AllowOnAllVirtualKeys clients with no explicit VK config, fall back to allowing all tools + for clientID, clientName := range allowAllVKsClients { + if !handledClients[clientID] { + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName)) + } } + // Set even when empty to exclude tools when no tools are present in the virtual key config + headers["x-bf-mcp-include-tools"] = strings.Join(executeOnlyTools, ",") + return headers, nil } @@ -908,7 +1039,7 @@ func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) return nil } -// evaluateGovernanceRequest is a common function that handles virtual key validation +// EvaluateGovernanceRequest is a common function that handles virtual key validation // and governance evaluation logic. It returns the evaluation result and a BifrostError // if the request should be rejected, or nil if allowed. // @@ -919,7 +1050,7 @@ func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) // Returns: // - *EvaluationResult: The governance evaluation result // - *schemas.BifrostError: The error to return if request is not allowed, nil if allowed -func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) { +func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) { // Check if authentication is mandatory (either VK or user auth) // Checking if the virtual key is valid or not isVirtualKeyValid := false @@ -927,6 +1058,15 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext _, exists := p.store.GetVirtualKey(evaluationRequest.VirtualKey) if exists { isVirtualKeyValid = true + } else { + // VK was provided but does not exist in the store β€” reject regardless of mandatory setting + return nil, &schemas.BifrostError{ + Type: bifrost.Ptr("virtual_key_not_found"), + StatusCode: bifrost.Ptr(401), + Error: &schemas.ErrorField{ + Message: "virtual key not found. The provided virtual key does not exist or has been revoked.", + }, + } } } p.cfgMutex.RLock() @@ -958,10 +1098,37 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext if result.Decision == DecisionAllow && evaluationRequest.VirtualKey != "" { if evaluationRequest.UserID != "" { // User auth present: only use VK for routing/filtering (skip rate limits and budgets) - result = p.resolver.EvaluateVirtualKeyFiltering(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType) + result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, true) } else { // No user auth: full VK governance (routing + limits) - result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType) + result = p.resolver.EvaluateVirtualKeyRequest(ctx, evaluationRequest.VirtualKey, evaluationRequest.Provider, evaluationRequest.Model, requestType, false) + } + } + + // Check the actual MCP tools injected into the request against the VK MCPConfigs. + // BifrostContextKeyMCPAddedTools is populated by AddToolsToRequest (which runs before + // PreLLMHook), so it contains the real expanded tool names (e.g. "youtube-search") rather + // than raw header patterns (e.g. "youtube-*"), giving us exact per-tool validation. + if result.Decision == DecisionAllow && result.VirtualKey != nil { + if addedTools, ok := ctx.Value(schemas.BifrostContextKeyMCPAddedTools).([]string); ok && len(addedTools) > 0 { + // Fetch once before the loop to avoid repeated lock acquisitions per tool. + var allowAllClients map[string]string + if p.inMemoryStore != nil { + allowAllClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() + } + var disallowed []string + for _, tool := range addedTools { + if !p.isMCPToolAllowedByVKWith(result.VirtualKey, tool, allowAllClients) { + disallowed = append(disallowed, tool) + } + } + if len(disallowed) > 0 { + result = &EvaluationResult{ + Decision: DecisionMCPToolBlocked, + Reason: fmt.Sprintf("MCP tools not allowed for virtual key '%s': %s", result.VirtualKey.Name, strings.Join(disallowed, ", ")), + VirtualKey: result.VirtualKey, + } + } } } @@ -1006,6 +1173,15 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext }, } + case DecisionMCPToolBlocked: + return result, &schemas.BifrostError{ + Type: bifrost.Ptr(string(result.Decision)), + StatusCode: bifrost.Ptr(403), + Error: &schemas.ErrorField{ + Message: result.Reason, + }, + } + default: // Fallback to deny for unknown decisions return result, &schemas.BifrostError{ @@ -1017,6 +1193,53 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext } } +// isMCPToolAllowedByVK checks whether a tool pattern (in "clientName-toolName" or "clientName-*" +// format) is permitted by the virtual key's MCPConfigs. +// +// Priority order: +// 1. If the VK has an explicit MCP config for this client, that config is authoritative (can allow or deny). +// 2. If no explicit config exists and the client has AllowOnAllVirtualKeys=true, all tools are allowed. +// +// For wildcard patterns ("clientName-*"): allowed if VK has the client configured with any tools. +// Specific tool enforcement happens at execution time via checkVKMCPToolAllowance. +// For specific tools ("clientName-toolName"): allowed if VK has "*" or the exact tool name. +func (p *GovernancePlugin) isMCPToolAllowedByVK(vk *configstoreTables.TableVirtualKey, toolPattern string) bool { + var allowAllClients map[string]string + if p.inMemoryStore != nil { + allowAllClients = p.inMemoryStore.GetMCPClientsAllowingAllVirtualKeys() + } + return p.isMCPToolAllowedByVKWith(vk, toolPattern, allowAllClients) +} + +// isMCPToolAllowedByVKWith checks whether a tool pattern is allowed by the virtual key, +// using a pre-fetched allowAllClients map (clientID β†’ clientName) to avoid repeated lock +// acquisitions in loops. +func (p *GovernancePlugin) isMCPToolAllowedByVKWith(vk *configstoreTables.TableVirtualKey, toolPattern string, allowAllClients map[string]string) bool { + // Check VK-specific MCP configs first β€” explicit config always overrides AllowOnAllVirtualKeys. + for _, mcpConfig := range vk.MCPConfigs { + clientName := mcpConfig.MCPClient.Name + if toolPattern != clientName+"-*" && !strings.HasPrefix(toolPattern, clientName+"-") { + continue + } + // Found an explicit config for this client β€” use it; do not fall back to AllowOnAllVirtualKeys. + if toolPattern == clientName+"-*" { + return !mcpConfig.ToolsToExecute.IsEmpty() + } + if mcpConfig.ToolsToExecute.IsUnrestricted() { + return true + } + toolSuffix := strings.TrimPrefix(toolPattern, clientName+"-") + return mcpConfig.ToolsToExecute.Contains(toolSuffix) + } + // No explicit VK config found β€” fall back to AllowOnAllVirtualKeys (allows all tools). + for _, clientName := range allowAllClients { + if strings.HasPrefix(toolPattern, clientName+"-") || toolPattern == clientName+"-*" { + return true + } + } + return false +} + // PreLLMHook intercepts requests before they are processed (governance decision point) // Parameters: // - ctx: The Bifrost context @@ -1049,7 +1272,7 @@ func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas. UserID: userID, } // Evaluate governance using common function - _, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) + _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) // Convert BifrostError to LLMPluginShortCircuit if needed if bifrostError != nil { return req, &schemas.LLMPluginShortCircuit{ @@ -1076,7 +1299,7 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche } // Extract request type, provider, and model - requestType, provider, model := bifrost.GetResponseFields(result, err) + requestType, provider, requestedModel, _ := bifrost.GetResponseFields(result, err) // Extract governance information virtualKey := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) @@ -1084,20 +1307,6 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche // Extract user ID for enterprise user-level governance userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceUserID) - // Extract cache and batch flags from context - isCacheRead := false - isBatch := false - if val := ctx.Value(governanceIsCacheReadContextKey); val != nil { - if b, ok := val.(bool); ok { - isCacheRead = b - } - } - if val := ctx.Value(governanceIsBatchContextKey); val != nil { - if b, ok := val.(bool); ok { - isBatch = b - } - } - if requestType == schemas.ListModelsRequest && result != nil && result.ListModelsResponse != nil && virtualKey != "" { // filter models which are not supported on this virtual key result.ListModelsResponse.Data = p.filterModelsForVirtualKey(result.ListModelsResponse.Data, virtualKey) @@ -1105,6 +1314,9 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche isFinalChunk := bifrost.IsFinalChunk(ctx) + // Build pricing scopes from context using the governance VK ID (not the raw VK token) + pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(provider)) + // Always process usage tracking (with or without virtual key) // When user auth is present, skip VK usage tracking to avoid double-counting effectiveVK := virtualKey @@ -1113,11 +1325,12 @@ func (p *GovernancePlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche } // If effectiveVK is empty, it will be passed as empty string to postHookWorker // The tracker will handle empty virtual keys gracefully by only updating provider-level and model-level usage - if model != "" { + if requestedModel != "" { p.wg.Add(1) go func() { defer p.wg.Done() - p.postHookWorker(result, provider, model, requestType, effectiveVK, requestID, userID, isCacheRead, isBatch, isFinalChunk) + // Use the requested model for usage tracking + p.postHookWorker(result, provider, requestedModel, requestType, effectiveVK, requestID, userID, isFinalChunk, pricingScopes) }() } @@ -1158,7 +1371,7 @@ func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas. } // Evaluate governance using common function - _, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest) + _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest) // Convert BifrostError to MCPPluginShortCircuit if needed if bifrostError != nil { @@ -1167,6 +1380,34 @@ func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas. }, nil } + // Blind single-tool check: validate the specific tool being executed against VK MCPConfigs. + // This runs independently of EvaluateGovernanceRequest to enforce execution-time allow-list. + if virtualKeyValue != "" { + vk, ok := p.store.GetVirtualKey(virtualKeyValue) + if !ok || vk == nil || !vk.IsActive { + // VK became invalid after initial check - fail closed for security + ctx.SetValue(governanceRejectedContextKey, true) + return req, &schemas.MCPPluginShortCircuit{Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(DecisionVirtualKeyNotFound)), + StatusCode: bifrost.Ptr(403), + Error: &schemas.ErrorField{ + Message: "Virtual key not found", + }, + }}, nil + } + if !p.isMCPToolAllowedByVK(vk, toolName) { + ctx.SetValue(governanceRejectedContextKey, true) + return req, &schemas.MCPPluginShortCircuit{Error: &schemas.BifrostError{ + Type: bifrost.Ptr(string(DecisionMCPToolBlocked)), + StatusCode: bifrost.Ptr(403), + Error: &schemas.ErrorField{ + Message: fmt.Sprintf("MCP tool '%s' is not allowed for virtual key '%s'", toolName, vk.Name), + }, + }}, nil + } + return req, nil, nil + } + return req, nil, nil } @@ -1263,13 +1504,15 @@ func (p *GovernancePlugin) Cleanup() error { // - provider: The provider of the request // - model: The model of the request // - requestType: The type of the request -// - virtualKey: The virtual key of the request (empty string if not present) +// - virtualKey: The raw virtual key token of the request (empty string if not present) +// - selectedKeyID: The selected provider key ID used for scoped pricing overrides // - requestID: The request ID // - userID: The user ID for enterprise user-level governance (empty string if not present) // - isCacheRead: Whether the request is a cache read // - isBatch: Whether the request is a batch request // - isFinalChunk: Whether the request is the final chunk -func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, _, _, isFinalChunk bool) { +// - pricingScopes: Prebuilt pricing lookup scopes using governance VK ID (nil if not applicable) +func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType, virtualKey, requestID, userID string, isFinalChunk bool, pricingScopes *modelcatalog.PricingLookupScopes) { // Determine if request was successful success := (result != nil) @@ -1279,7 +1522,7 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi if !isStreaming || (isStreaming && isFinalChunk) { var cost float64 if p.modelCatalog != nil && result != nil { - cost = p.modelCatalog.CalculateCost(result) + cost = p.modelCatalog.CalculateCost(result, pricingScopes) } tokensUsed := 0 if result != nil { diff --git a/plugins/governance/model_provider_governance_test.go b/plugins/governance/modelprovidergovernance_test.go similarity index 99% rename from plugins/governance/model_provider_governance_test.go rename to plugins/governance/modelprovidergovernance_test.go index a17855552a..a7889f14bd 100644 --- a/plugins/governance/model_provider_governance_test.go +++ b/plugins/governance/modelprovidergovernance_test.go @@ -1480,6 +1480,9 @@ func TestPreLLMHook_ModelProviderPass_VirtualKeyChecksPass(t *testing.T) { // Model/provider checks pass (no limits) // Virtual key checks also pass vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, }, nil) @@ -1702,9 +1705,9 @@ func TestPostHook_UpdatesProviderBudgetUsage_NoVirtualKey(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1771,9 +1774,9 @@ func TestPostHook_UpdatesProviderRateLimitUsage_NoVirtualKey(t *testing.T) { TotalTokens: 10000, // 10000 tokens used (exactly at limit) }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1838,9 +1841,9 @@ func TestPostHook_UpdatesModelBudgetUsage_NoVirtualKey(t *testing.T) { TotalTokens: 1500, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } @@ -1907,9 +1910,9 @@ func TestPostHook_UpdatesModelRateLimitUsage_NoVirtualKey(t *testing.T) { TotalTokens: 10000, // 10000 tokens used (exactly at limit) }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.OpenAI, - ModelRequested: "gpt-4", + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4", }, }, } diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 8d3da777af..bf184908d3 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -4,7 +4,6 @@ package governance import ( "context" "fmt" - "slices" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" @@ -24,6 +23,7 @@ const ( DecisionRequestLimited Decision = "request_limited" DecisionModelBlocked Decision = "model_blocked" DecisionProviderBlocked Decision = "provider_blocked" + DecisionMCPToolBlocked Decision = "mcp_tool_blocked" ) // EvaluationRequest contains the context for evaluating a request @@ -170,7 +170,8 @@ func (r *BudgetResolver) isModelRequired(requestType schemas.RequestType) bool { } // EvaluateVirtualKeyRequest evaluates virtual key-specific checks including validation, filtering, rate limits, and budgets -func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType) *EvaluationResult { +// skipRateLimitsAndBudgets evaluates to true when we want to skip rate limits and budgets. This is used when user auth is present (user governance handles limits). +func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType, skipRateLimitsAndBudgets bool) *EvaluationResult { // 1. Validate virtual key exists and is active vk, exists := r.store.GetVirtualKey(virtualKeyValue) if !exists { @@ -224,109 +225,45 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, } // 4. Check rate limits hierarchy (VK level) - if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, evaluationRequest); rateLimitResult != nil { - return rateLimitResult - } - - // 5. Check budget hierarchy (VK β†’ Team β†’ Customer) - if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest); budgetResult != nil { - return budgetResult - } - - // Find the provider config that matches the request's provider and get its allowed keys - for _, pc := range vk.ProviderConfigs { - if schemas.ModelProvider(pc.Provider) == provider && len(pc.Keys) > 0 { - includeOnlyKeys := make([]string, 0, len(pc.Keys)) - for _, dbKey := range pc.Keys { - includeOnlyKeys = append(includeOnlyKeys, dbKey.KeyID) - } - ctx.SetValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys, includeOnlyKeys) - break + if !skipRateLimitsAndBudgets { + if rateLimitResult := r.checkRateLimitHierarchy(ctx, vk, evaluationRequest); rateLimitResult != nil { + return rateLimitResult } - } - // All checks passed - return &EvaluationResult{ - Decision: DecisionAllow, - Reason: "Request allowed by governance policy", - VirtualKey: vk, - } -} - -// EvaluateVirtualKeyFiltering evaluates virtual key checks for routing and model/provider filtering only, -// skipping rate limits and budgets. Used when user auth is present (user governance handles limits). -func (r *BudgetResolver) EvaluateVirtualKeyFiltering(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType) *EvaluationResult { - // 1. Validate virtual key exists and is active - vk, exists := r.store.GetVirtualKey(virtualKeyValue) - if !exists { - return &EvaluationResult{ - Decision: DecisionVirtualKeyNotFound, - Reason: "Virtual key not found", - } - } - // Set virtual key id and name in context - ctx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, vk.ID) - ctx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyName, vk.Name) - if vk.Team != nil { - ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamID, vk.Team.ID) - ctx.SetValue(schemas.BifrostContextKeyGovernanceTeamName, vk.Team.Name) - if vk.Team.Customer != nil { - ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, vk.Team.Customer.ID) - ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, vk.Team.Customer.Name) - } - } - if vk.Customer != nil { - ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerID, vk.Customer.ID) - ctx.SetValue(schemas.BifrostContextKeyGovernanceCustomerName, vk.Customer.Name) - } - if !vk.IsActive { - return &EvaluationResult{ - Decision: DecisionVirtualKeyBlocked, - Reason: "Virtual key is inactive", - } - } - // 2. Check provider filtering - if requestType != schemas.MCPToolExecutionRequest && !r.isProviderAllowed(vk, provider) { - return &EvaluationResult{ - Decision: DecisionProviderBlocked, - Reason: fmt.Sprintf("Provider '%s' is not allowed for this virtual key", provider), - VirtualKey: vk, - } - } - // 3. Check model filtering - if r.isModelRequired(requestType) && !r.isModelAllowed(vk, provider, model) { - return &EvaluationResult{ - Decision: DecisionModelBlocked, - Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", model), - VirtualKey: vk, + // 5. Check budget hierarchy (VK β†’ Team β†’ Customer) + if budgetResult := r.checkBudgetHierarchy(ctx, vk, evaluationRequest); budgetResult != nil { + return budgetResult } } - // Set include-only keys for provider config routing + // Find the provider config that matches the request's provider and apply key filtering for _, pc := range vk.ProviderConfigs { - if schemas.ModelProvider(pc.Provider) == provider && len(pc.Keys) > 0 { - includeOnlyKeys := make([]string, 0, len(pc.Keys)) - for _, dbKey := range pc.Keys { - includeOnlyKeys = append(includeOnlyKeys, dbKey.KeyID) + if schemas.ModelProvider(pc.Provider) == provider { + if !pc.AllowAllKeys { + // Restrict to specific keys (empty slice = no keys allowed) + includeOnlyKeys := make([]string, 0, len(pc.Keys)) + for _, dbKey := range pc.Keys { + includeOnlyKeys = append(includeOnlyKeys, dbKey.KeyID) + } + ctx.SetValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys, includeOnlyKeys) } - ctx.SetValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys, includeOnlyKeys) break } } - // Skip rate limits and budgets β€” user auth handles those + // All checks passed return &EvaluationResult{ Decision: DecisionAllow, - Reason: "Request allowed by governance policy (VK filtering only)", + Reason: "Request allowed by governance policy", VirtualKey: vk, } } // isModelAllowed checks if the requested model is allowed for this VK func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, model string) bool { - // Empty ProviderConfigs means all models are allowed + // Empty ProviderConfigs means no models are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return true + return false } for _, pc := range vk.ProviderConfigs { @@ -338,10 +275,8 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p return r.modelCatalog.IsModelAllowedForProvider(provider, model, pc.AllowedModels) } // Fallback when model catalog is not available: simple string matching - if len(pc.AllowedModels) == 0 { - return true - } - return slices.Contains(pc.AllowedModels, model) + // ["*"] = allow all models; [] = deny all models + return pc.AllowedModels.IsAllowed(model) } } @@ -350,9 +285,9 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p // isProviderAllowed checks if the requested provider is allowed for this VK func (r *BudgetResolver) isProviderAllowed(vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider) bool { - // Empty AllowedProviders means all providers are allowed + // Empty ProviderConfigs means no providers are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return true + return false } for _, pc := range vk.ProviderConfigs { @@ -418,7 +353,7 @@ func (r *BudgetResolver) isProviderBudgetViolated(ctx context.Context, vk *confi } // 2. Check VK-level provider config budget - if config.Budget == nil { + if len(config.Budgets) == 0 { return false } if err := r.store.CheckBudget(ctx, vk, request, nil); err != nil { diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go index 7e55c57328..fceb69a737 100644 --- a/plugins/governance/resolver_test.go +++ b/plugins/governance/resolver_test.go @@ -17,6 +17,9 @@ import ( func TestBudgetResolver_EvaluateRequest_AllowedRequest(t *testing.T) { logger := NewMockLogger() vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, @@ -26,7 +29,7 @@ func TestBudgetResolver_EvaluateRequest_AllowedRequest(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionAllow, result) assertVirtualKeyFound(t, result) @@ -41,7 +44,7 @@ func TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-nonexistent", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-nonexistent", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionVirtualKeyNotFound, result) } @@ -59,7 +62,7 @@ func TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionVirtualKeyBlocked, result) } @@ -83,7 +86,7 @@ func TestBudgetResolver_EvaluateRequest_ProviderBlocked(t *testing.T) { ctx := &schemas.BifrostContext{} // Try to use OpenAI (not allowed) - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionProviderBlocked, result) assertVirtualKeyFound(t, result) @@ -100,7 +103,6 @@ func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { AllowedModels: []string{"gpt-4", "gpt-4-turbo"}, // Only these models Weight: bifrost.Ptr(1.0), RateLimit: nil, - Budget: nil, Keys: []configstoreTables.TableKey{}, }, } @@ -115,7 +117,7 @@ func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { ctx := &schemas.BifrostContext{} // Try to use gpt-4o-mini (not in allowed list) - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4o-mini", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4o-mini", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionModelBlocked, result) } @@ -137,7 +139,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit(t *testing. resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionTokenLimited, result) assertRateLimitInfo(t, result) @@ -160,7 +162,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit(t *testin resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionRequestLimited, result) } @@ -198,7 +200,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExpired(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) // Should allow because rate limit was expired and has been reset assertDecision(t, DecisionAllow, result) @@ -220,7 +222,7 @@ func TestBudgetResolver_EvaluateRequest_BudgetExceeded(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionBudgetExceeded, result) } @@ -247,7 +249,7 @@ func TestBudgetResolver_EvaluateRequest_BudgetExpired(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) // Should allow because budget is expired (will be reset) assertDecision(t, DecisionAllow, result) @@ -282,7 +284,7 @@ func TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) ctx := &schemas.BifrostContext{} // Test: All under limit should pass - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionAllow, result) // Test: VK budget exceeds should fail @@ -293,7 +295,7 @@ func TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) vkBudgetToUpdate.CurrentUsage = 100.0 store.budgets.Store("vk-budget", vkBudgetToUpdate) } - result = resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result = resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionBudgetExceeded, result) } @@ -315,7 +317,7 @@ func TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionTokenLimited, result) assertRateLimitInfo(t, result) @@ -338,7 +340,7 @@ func TestBudgetResolver_CheckRateLimits_BothExceeded(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assertDecision(t, DecisionRateLimited, result) assert.Contains(t, result.Reason, "rate limit") @@ -359,10 +361,10 @@ func TestBudgetResolver_IsProviderAllowed(t *testing.T) { shouldBeAllowed bool }{ { - name: "No provider configs (all allowed)", + name: "No provider configs (none allowed - deny-by-default)", vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), provider: schemas.OpenAI, - shouldBeAllowed: true, + shouldBeAllowed: false, }, { name: "Provider in allowlist", @@ -408,22 +410,32 @@ func TestBudgetResolver_IsModelAllowed(t *testing.T) { shouldBeAllowed bool }{ { - name: "No provider configs (all models allowed)", + name: "No provider configs (no models allowed - deny-by-default)", vk: buildVirtualKey("vk1", "sk-bf-test", "Test", true), provider: schemas.OpenAI, model: "gpt-4", - shouldBeAllowed: true, + shouldBeAllowed: false, }, { - name: "Empty allowed models (all models allowed)", + name: "Wildcard allowed models (all models allowed)", vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", []configstoreTables.TableVirtualKeyProviderConfig{ - buildProviderConfig("openai", []string{}), // Empty = all allowed + buildProviderConfig("openai", []string{"*"}), // ["*"] = allow all }), provider: schemas.OpenAI, model: "gpt-4", shouldBeAllowed: true, }, + { + name: "Empty allowed models (deny all)", + vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", + []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{}), // [] = deny all + }), + provider: schemas.OpenAI, + model: "gpt-4", + shouldBeAllowed: false, + }, { name: "Model in allowlist", vk: buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test", @@ -458,6 +470,9 @@ func TestBudgetResolver_IsModelAllowed(t *testing.T) { func TestBudgetResolver_ContextPopulation(t *testing.T) { logger := NewMockLogger() vk := buildVirtualKey("vk1", "sk-bf-test", "Test VK", true) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } customer := buildCustomer("cust1", "Customer 1", nil) team := buildTeam("team1", "Team 1", nil) team.CustomerID = &customer.ID @@ -476,7 +491,7 @@ func TestBudgetResolver_ContextPopulation(t *testing.T) { resolver := NewBudgetResolver(store, nil, logger) ctx := &schemas.BifrostContext{} - result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) assert.Equal(t, DecisionAllow, result.Decision) diff --git a/plugins/governance/routing.go b/plugins/governance/routing.go index b5592a6049..b32044be02 100644 --- a/plugins/governance/routing.go +++ b/plugins/governance/routing.go @@ -11,6 +11,9 @@ import ( configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) +// DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation. +const DefaultRoutingChainMaxDepth = 10 + // headerKeyPattern matches header map access patterns like headers["X-Api-Key"] or headers['X-Api-Key'] var headerKeyPattern = regexp.MustCompile(`headers\[["']([^"']+)["']\]`) @@ -54,28 +57,41 @@ type RoutingContext struct { } type RoutingEngine struct { - store GovernanceStore - logger schemas.Logger + store GovernanceStore + logger schemas.Logger + chainMaxDepth *int // pointer to live config value; changes are reflected immediately } // NewRoutingEngine creates a new RoutingEngine -func NewRoutingEngine(store GovernanceStore, logger schemas.Logger) (*RoutingEngine, error) { +func NewRoutingEngine(store GovernanceStore, logger schemas.Logger, chainMaxDepth *int) (*RoutingEngine, error) { if store == nil { return nil, fmt.Errorf("store cannot be nil") } - if logger == nil { return nil, fmt.Errorf("logger cannot be nil") } + if chainMaxDepth == nil { + return nil, fmt.Errorf("chainMaxDepth cannot be nil") + } + if *chainMaxDepth <= 0 { + return nil, fmt.Errorf("chainMaxDepth must be greater than 0") + } return &RoutingEngine{ - store: store, - logger: logger, + store: store, + logger: logger, + chainMaxDepth: chainMaxDepth, }, nil } -// EvaluateRoutingRules evaluates routing rules for a given context and returns routing decision -// Implements scope precedence: VirtualKey > Team > Customer > Global (first-match-wins) +// EvaluateRoutingRules evaluates routing rules for a given context and returns a routing decision. +// Implements scope precedence: VirtualKey > Team > Customer > Global (first-match-wins within each iteration). +// When a matched rule has chain_rule=true, the resolved provider/model is fed back into the evaluator +// and the full scope chain is re-evaluated with the updated context. This repeats until: +// 1. No rule matches the current context +// 2. A terminal rule matches (chain_rule=false, the default) +// 3. A cycle is detected (a provider/model state was already visited) +// 4. The chain exceeds the configured max depth (chainMaxDepth, default 10) func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routingCtx *RoutingContext) (*RoutingDecision, error) { if routingCtx == nil { return nil, fmt.Errorf("routing context cannot be nil") @@ -83,64 +99,95 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi re.logger.Debug("[RoutingEngine] Starting rule evaluation for provider=%s, model=%s", routingCtx.Provider, routingCtx.Model) - // Extract CEL variables from routing context - variables, err := extractRoutingVariables(routingCtx) - if err != nil { - re.logger.Error("[RoutingEngine] Failed to extract routing variables: %v", err) - return nil, fmt.Errorf("failed to extract routing variables: %w", err) + // Mutable provider/model that advances through the chain; all other context fields are immutable. + currentProvider := routingCtx.Provider + currentModel := routingCtx.Model + + // Track visited provider/model states to detect cycles (e.g. Aβ†’Bβ†’A). + visited := map[string]struct{}{ + fmt.Sprintf("%s|%s", currentProvider, currentModel): {}, } - // Determine scope chain based on organizational hierarchy - scopeChain := buildScopeChain(routingCtx.VirtualKey) - re.logger.Debug("[RoutingEngine] Scope chain: %v", scopeChainToStrings(scopeChain)) - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Scope chain: %v", scopeChainToStrings(scopeChain))) + var finalDecision *RoutingDecision - // Evaluate rules in scope precedence order (first-match-wins) - for _, scope := range scopeChain { - scopeID := scope.ScopeID + for chainStep := 0; ; chainStep++ { + // TERMINATION 4: Chain exceeded configured max depth. + maxDepth := *re.chainMaxDepth + if chainStep >= maxDepth { + re.logger.Warn("[RoutingEngine] Routing rule chain exceeded max depth (%d), stopping", maxDepth) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain exceeded max depth (%d) at step %d, stopping. Final resolved: provider=%s, model=%s", maxDepth, chainStep, currentProvider, currentModel)) + break + } - // Get all enabled rules for this scope, ordered by priority ASC - rules := re.store.GetScopedRoutingRules(scope.ScopeName, scopeID) - re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules)) + if chainStep > 0 { + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain step %d: re-evaluating with provider=%s, model=%s", chainStep, currentProvider, currentModel)) + } - if len(rules) == 0 { - continue + // Build CEL variables for the current chain step's provider/model. + iterCtx := *routingCtx + iterCtx.Provider = currentProvider + iterCtx.Model = currentModel + // Refresh budget/rate-limit status for the current provider/model so chained + // rules that test budget_used, tokens_used, or request see fresh data. + iterCtx.BudgetAndRateLimitStatus = re.store.GetBudgetAndRateLimitStatus(ctx, currentModel, currentProvider, routingCtx.VirtualKey, nil, nil, nil) + + variables, err := extractRoutingVariables(&iterCtx) + if err != nil { + re.logger.Error("[RoutingEngine] Failed to extract routing variables: %v", err) + return nil, fmt.Errorf("failed to extract routing variables: %w", err) } - ruleNames := make([]string, 0, len(rules)) - for _, r := range rules { - ruleNames = append(ruleNames, r.Name) + scopeChain := buildScopeChain(routingCtx.VirtualKey) + re.logger.Debug("[RoutingEngine] Scope chain (step=%d): %v", chainStep, scopeChainToStrings(scopeChain)) + if chainStep == 0 { + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Scope chain: %v", scopeChainToStrings(scopeChain))) } - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", "))) - // Evaluate each rule - for _, rule := range rules { - re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression) + var stepDecision *RoutingDecision + var matchedRule *configstoreTables.TableRoutingRule + var matchedTargetWeight float64 - // Get or compile and cache the CEL program - program, err := re.store.GetRoutingProgram(rule) - if err != nil { - re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err) - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err)) + outerLoop: + for _, scope := range scopeChain { + scopeID := scope.ScopeID + + rules := re.store.GetScopedRoutingRules(scope.ScopeName, scopeID) + re.logger.Debug("[RoutingEngine] Evaluating scope=%s, scopeID=%s, ruleCount=%d", scope.ScopeName, scopeID, len(rules)) + + if len(rules) == 0 { continue } - // Evaluate the CEL expression - matched, err := evaluateCELExpression(program, variables) - if err != nil { - re.logger.Warn("[RoutingEngine] Failed to evaluate rule %s: %v", rule.Name, err) - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: eval error: %v", rule.Name, err)) - continue + ruleNames := make([]string, 0, len(rules)) + for _, r := range rules { + ruleNames = append(ruleNames, r.Name) } + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Evaluating scope %s: %d rules [%s]", scope.ScopeName, len(rules), strings.Join(ruleNames, ", "))) - re.logger.Debug("[RoutingEngine] Rule %s evaluation result: matched=%v", rule.Name, matched) + for _, rule := range rules { + re.logger.Debug("[RoutingEngine] Evaluating rule: name=%s, expression=%s", rule.Name, rule.CelExpression) - if !matched { - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] β†’ no match", rule.Name, rule.CelExpression)) - } + program, err := re.store.GetRoutingProgram(rule) + if err != nil { + re.logger.Warn("[RoutingEngine] Failed to compile rule %s: %v", rule.Name, err) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: compile error: %v", rule.Name, err)) + continue + } + + matched, err := evaluateCELExpression(program, variables) + if err != nil { + re.logger.Warn("[RoutingEngine] Failed to evaluate rule %s: %v", rule.Name, err) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' skipped: eval error: %v", rule.Name, err)) + continue + } + + re.logger.Debug("[RoutingEngine] Rule %s evaluation result: matched=%v", rule.Name, matched) + + if !matched { + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] β†’ no match", rule.Name, rule.CelExpression)) + continue + } - // If rule matched, select a target probabilistically and return routing decision - if matched { target, ok := selectWeightedTarget(rule.Targets) if !ok { re.logger.Debug("[RoutingEngine] Rule %s matched but has no valid targets (empty list or all-negative weights), skipping β€” note: all-zero weights use uniform selection and would not reach here", rule.Name) @@ -148,12 +195,12 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi continue } - provider := string(routingCtx.Provider) + provider := string(currentProvider) if target.Provider != nil && *target.Provider != "" { provider = *target.Provider } - model := routingCtx.Model + model := currentModel if target.Model != nil && *target.Model != "" { model = *target.Model } @@ -163,7 +210,7 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi keyID = *target.KeyID } - decision := &RoutingDecision{ + stepDecision = &RoutingDecision{ Provider: provider, Model: model, KeyID: keyID, @@ -171,21 +218,52 @@ func (re *RoutingEngine) EvaluateRoutingRules(ctx *schemas.BifrostContext, routi MatchedRuleID: rule.ID, MatchedRuleName: rule.Name, } + matchedRule = rule + matchedTargetWeight = target.Weight + break outerLoop + } + } - ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleID, rule.ID) - ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleName, rule.Name) + // TERMINATION 1: No rule matched this iteration. + if stepDecision == nil { + break + } - re.logger.Debug("[RoutingEngine] Rule matched! Selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v", target.Weight, provider, model, rule.ParsedFallbacks) - ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] β†’ matched, selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v", rule.Name, rule.CelExpression, target.Weight, provider, model, rule.ParsedFallbacks)) - return decision, nil - } + // Accumulate: last match wins for all fields. + finalDecision = stepDecision + ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleID, stepDecision.MatchedRuleID) + ctx.SetValue(schemas.BifrostContextKeyGovernanceRoutingRuleName, stepDecision.MatchedRuleName) + chainSuffix := "" + if matchedRule.ChainRule { + chainSuffix = " [chain_rule=true, continuing]" } + re.logger.Debug("[RoutingEngine] Rule matched! Selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Rule '%s' [%s] β†’ matched, selected target (weight=%.2f): provider=%s, model=%s, fallbacks=%v%s", matchedRule.Name, matchedRule.CelExpression, matchedTargetWeight, stepDecision.Provider, stepDecision.Model, stepDecision.Fallbacks, chainSuffix)) + + // TERMINATION 2: Rule is terminal (chain_rule=false, the default). + if !matchedRule.ChainRule { + break + } + + // TERMINATION 3: Cycle detection β€” if the next state was already visited, continuing would loop forever. + nextState := fmt.Sprintf("%s|%s", stepDecision.Provider, stepDecision.Model) + if _, seen := visited[nextState]; seen { + re.logger.Debug("[RoutingEngine] Chain cycle detected at step=%d (state=%s already visited), stopping", chainStep, nextState) + ctx.AppendRoutingEngineLog(schemas.RoutingEngineRoutingRule, fmt.Sprintf("Chain cycle detected at step %d (provider=%s, model=%s already visited), stopping. Final resolved: provider=%s, model=%s", chainStep, stepDecision.Provider, stepDecision.Model, stepDecision.Provider, stepDecision.Model)) + break + } + visited[nextState] = struct{}{} + + // Advance context for next chain iteration. + currentProvider = schemas.ModelProvider(stepDecision.Provider) + currentModel = stepDecision.Model } - // No rule matched - return nil decision (caller should use default routing) - re.logger.Debug("[RoutingEngine] No routing rule matched, using default routing") - return nil, nil + if finalDecision == nil { + re.logger.Debug("[RoutingEngine] No routing rule matched, using default routing") + } + return finalDecision, nil } // selectWeightedTarget picks one target from the slice using weighted random selection. diff --git a/plugins/governance/routing_test.go b/plugins/governance/routing_test.go index 045104effc..02aaa6a59a 100644 --- a/plugins/governance/routing_test.go +++ b/plugins/governance/routing_test.go @@ -194,7 +194,7 @@ func TestEvaluateRoutingRules_NilContext(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) _, err = engine.EvaluateRoutingRules(schemas.NewBifrostContext(context.Background(), time.Now()), nil) @@ -207,7 +207,7 @@ func TestEvaluateRoutingRules_NoRulesMatch(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) ctx := &RoutingContext{ @@ -228,7 +228,7 @@ func TestEvaluateRoutingRules_GlobalRuleMatches(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) // Create a global routing rule @@ -279,7 +279,7 @@ func TestEvaluateRoutingRules_MultiTargetDeterministicWithPinnedKey(t *testing.T store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) @@ -348,7 +348,7 @@ func TestEvaluateRoutingRules_ScopePrecedence(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) // Create global rule @@ -412,7 +412,7 @@ func TestEvaluateRoutingRules_PriorityOrdering(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) // Low precedence rule (evaluated second): higher priority number @@ -485,7 +485,7 @@ func TestResolveRoutingWithFallback_RuleMatches(t *testing.T) { QueryParams: map[string]string{}, } - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) decision, err := resolveRoutingWithFallback(bgCtx, ctx, engine) @@ -508,7 +508,7 @@ func TestResolveRoutingWithFallback_NoMatch(t *testing.T) { QueryParams: map[string]string{}, } - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) decision, err := resolveRoutingWithFallback(schemas.NewBifrostContext(context.Background(), time.Now()), ctx, engine) @@ -527,7 +527,7 @@ func TestEvaluateRoutingRules_DisabledRulesIgnored(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) // Create disabled rule @@ -579,7 +579,7 @@ func TestEvaluateRoutingRules_ComplexExpression(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) rule := &configstoreTables.TableRoutingRule{ @@ -623,7 +623,7 @@ func TestEvaluateRoutingRules_NilVirtualKey(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) rule := &configstoreTables.TableRoutingRule{ @@ -659,7 +659,7 @@ func TestEvaluateRoutingRules_MissingHeaderGracefully(t *testing.T) { require.NoError(t, err) bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) - engine, err := NewRoutingEngine(store, NewMockLogger()) + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) require.NoError(t, err) // Create a rule that checks for a header that may not be present @@ -697,6 +697,239 @@ func TestEvaluateRoutingRules_MissingHeaderGracefully(t *testing.T) { assert.Equal(t, "azure", decision.Provider) } +// TestEvaluateRoutingRules_ChainRuleReEvaluation tests that chain_rule=true causes re-evaluation +// with the resolved provider/model fed back into the engine. +func TestEvaluateRoutingRules_ChainRuleReEvaluation(t *testing.T) { + store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) + + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) + require.NoError(t, err) + + // Rule A: matches gpt-4o β†’ routes to gpt-4-turbo, chain_rule=true so re-evaluation continues. + ruleA := &configstoreTables.TableRoutingRule{ + ID: "chain-a", + Name: "Chain Rule A", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4-turbo"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 0, + ChainRule: true, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + + // Rule B: matches gpt-4-turbo β†’ routes to azure/gpt-4, terminal (chain_rule=false). + ruleB := &configstoreTables.TableRoutingRule{ + ID: "chain-b", + Name: "Chain Rule B", + CelExpression: "model == 'gpt-4-turbo'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("azure"), Model: bifrost.Ptr("gpt-4"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 1, + ChainRule: false, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + + ctx := &RoutingContext{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Headers: map[string]string{}, + QueryParams: map[string]string{}, + } + + decision, err := engine.EvaluateRoutingRules(bgCtx, ctx) + require.NoError(t, err) + require.NotNil(t, decision) + + // Rule A matched first, but chain_rule=true caused re-evaluation. + // Rule B then matched the updated model (gpt-4-turbo) and produced the final result. + assert.Equal(t, "azure", decision.Provider) + assert.Equal(t, "gpt-4", decision.Model) + assert.Equal(t, "chain-b", decision.MatchedRuleID) +} + +// TestEvaluateRoutingRules_TerminalRuleStopsChain tests that a terminal rule (chain_rule=false) +// halts the chaining loop immediately without re-evaluation. +func TestEvaluateRoutingRules_TerminalRuleStopsChain(t *testing.T) { + store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) + + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) + require.NoError(t, err) + + // Rule A: matches gpt-4o β†’ routes to gpt-4-turbo, terminal (chain_rule=false). + ruleA := &configstoreTables.TableRoutingRule{ + ID: "terminal-a", + Name: "Terminal Rule A", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4-turbo"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 0, + ChainRule: false, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + + // Rule B: would match gpt-4-turbo, but should never be reached because Rule A is terminal. + ruleB := &configstoreTables.TableRoutingRule{ + ID: "terminal-b", + Name: "Terminal Rule B", + CelExpression: "model == 'gpt-4-turbo'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("azure"), Model: bifrost.Ptr("gpt-4"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 1, + ChainRule: false, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + + ctx := &RoutingContext{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Headers: map[string]string{}, + QueryParams: map[string]string{}, + } + + decision, err := engine.EvaluateRoutingRules(bgCtx, ctx) + require.NoError(t, err) + require.NotNil(t, decision) + + // Only Rule A should have matched; chain stopped immediately at terminal rule. + assert.Equal(t, "openai", decision.Provider) + assert.Equal(t, "gpt-4-turbo", decision.Model) + assert.Equal(t, "terminal-a", decision.MatchedRuleID) +} + +// TestEvaluateRoutingRules_ConvergenceStopsChain tests that the cycle-detection mechanism stops +// the chain when a chain_rule=true rule resolves to a provider/model already visited (no-op loop). +func TestEvaluateRoutingRules_ConvergenceStopsChain(t *testing.T) { + store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) + + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(10)) + require.NoError(t, err) + + // Rule A: chain_rule=true but resolves back to the initial provider/model β€” creates a cycle. + ruleA := &configstoreTables.TableRoutingRule{ + ID: "converge-a", + Name: "Convergence Rule A", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4o"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 0, + ChainRule: true, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + + ctx := &RoutingContext{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Headers: map[string]string{}, + QueryParams: map[string]string{}, + } + + decision, err := engine.EvaluateRoutingRules(bgCtx, ctx) + require.NoError(t, err) + require.NotNil(t, decision) + + // Cycle detected after the first match; the last matched decision (openai/gpt-4o) is returned. + assert.Equal(t, "openai", decision.Provider) + assert.Equal(t, "gpt-4o", decision.Model) + assert.Equal(t, "converge-a", decision.MatchedRuleID) +} + +// TestEvaluateRoutingRules_MaxDepthCutoff tests that the chain stops once chainMaxDepth is reached, +// returning the last successfully resolved decision rather than continuing further. +func TestEvaluateRoutingRules_MaxDepthCutoff(t *testing.T) { + store, err := NewLocalGovernanceStore(context.Background(), NewMockLogger(), nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + bgCtx := schemas.NewBifrostContext(context.Background(), time.Now()) + + // Use maxDepth=2: steps 0 and 1 are allowed; step 2 is cut off before any rule is evaluated. + engine, err := NewRoutingEngine(store, NewMockLogger(), schemas.Ptr(2)) + require.NoError(t, err) + + // Rule A: gpt-4o β†’ gpt-4-turbo, chain continues. + ruleA := &configstoreTables.TableRoutingRule{ + ID: "depth-a", + Name: "Depth Rule A", + CelExpression: "model == 'gpt-4o'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("openai"), Model: bifrost.Ptr("gpt-4-turbo"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 0, + ChainRule: true, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleA)) + + // Rule B: gpt-4-turbo β†’ azure/gpt-4, chain continues (would proceed to step 2 if depth allowed). + ruleB := &configstoreTables.TableRoutingRule{ + ID: "depth-b", + Name: "Depth Rule B", + CelExpression: "model == 'gpt-4-turbo'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("azure"), Model: bifrost.Ptr("gpt-4"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 1, + ChainRule: true, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleB)) + + // Rule C: gpt-4 β†’ anthropic/claude-3, would match at step 2 but max depth is 2. + ruleC := &configstoreTables.TableRoutingRule{ + ID: "depth-c", + Name: "Depth Rule C", + CelExpression: "model == 'gpt-4'", + Targets: []configstoreTables.TableRoutingTarget{ + {Provider: bifrost.Ptr("anthropic"), Model: bifrost.Ptr("claude-3"), Weight: 1.0}, + }, + Enabled: true, + Scope: "global", + Priority: 2, + ChainRule: false, + } + require.NoError(t, store.UpdateRoutingRuleInMemory(ruleC)) + + ctx := &RoutingContext{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Headers: map[string]string{}, + QueryParams: map[string]string{}, + } + + decision, err := engine.EvaluateRoutingRules(bgCtx, ctx) + require.NoError(t, err) + require.NotNil(t, decision) + + // Step 0: Rule A matched β†’ openai/gpt-4-turbo (finalDecision updated) + // Step 1: Rule B matched β†’ azure/gpt-4 (finalDecision updated) + // Step 2: chainStep (2) >= maxDepth (2) β†’ cut off before Rule C can match + // Final result is the last successful decision: azure/gpt-4 + assert.Equal(t, "azure", decision.Provider) + assert.Equal(t, "gpt-4", decision.Model) + assert.Equal(t, "depth-b", decision.MatchedRuleID) +} + // TestCompileAndCacheProgram_ValidExpression_Routing tests compiling and caching a valid CEL expression func TestCompileAndCacheProgram_ValidExpression_Routing(t *testing.T) { ctx := context.Background() diff --git a/plugins/governance/store.go b/plugins/governance/store.go index f9cc7dae0c..4775c71bd4 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -195,12 +195,19 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { if vk == nil { return } - if vk.BudgetID != nil { - if liveBudget, exists := gs.budgets.Load(*vk.BudgetID); exists && liveBudget != nil { - if b, ok := liveBudget.(*configstoreTables.TableBudget); ok { - vk.Budget = b + // Cross-reference live budget/rate limit from standalone maps + // (usage updates clone into budgets/rateLimits maps, so embedded pointers go stale) + // Hydrate multi-budgets from live sync.Map + if len(vk.Budgets) > 0 { + liveBudgets := make([]configstoreTables.TableBudget, 0, len(vk.Budgets)) + for _, b := range vk.Budgets { + if lb, exists := gs.budgets.Load(b.ID); exists && lb != nil { + if budget, ok := lb.(*configstoreTables.TableBudget); ok { + liveBudgets = append(liveBudgets, *budget) + } } } + vk.Budgets = liveBudgets } if vk.RateLimitID != nil { if liveRL, exists := gs.rateLimits.Load(*vk.RateLimitID); exists && liveRL != nil { @@ -213,12 +220,17 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { configs := make([]configstoreTables.TableVirtualKeyProviderConfig, len(vk.ProviderConfigs)) copy(configs, vk.ProviderConfigs) for i := range configs { - if configs[i].BudgetID != nil { - if liveBudget, exists := gs.budgets.Load(*configs[i].BudgetID); exists && liveBudget != nil { - if b, ok := liveBudget.(*configstoreTables.TableBudget); ok { - configs[i].Budget = b + // Hydrate provider config multi-budgets + if len(configs[i].Budgets) > 0 { + liveBudgets := make([]configstoreTables.TableBudget, 0, len(configs[i].Budgets)) + for _, b := range configs[i].Budgets { + if lb, exists := gs.budgets.Load(b.ID); exists && lb != nil { + if budget, ok := lb.(*configstoreTables.TableBudget); ok { + liveBudgets = append(liveBudgets, *budget) + } } } + configs[i].Budgets = liveBudgets } if configs[i].RateLimitID != nil { if liveRL, exists := gs.rateLimits.Load(*configs[i].RateLimitID); exists && liveRL != nil { @@ -373,7 +385,7 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { return true // continue iteration }) routingRules := make(map[string]*configstoreTables.TableRoutingRule) - gs.routingRules.Range(func(key, value interface{}) bool { + gs.routingRules.Range(func(key, value any) bool { rules, ok := value.([]*configstoreTables.TableRoutingRule) if !ok || rules == nil { return true // continue @@ -387,7 +399,7 @@ func (gs *LocalGovernanceStore) GetGovernanceData() *GovernanceData { return true // continue iteration }) var modelConfigsList []*configstoreTables.TableModelConfig - gs.modelConfigs.Range(func(key, value interface{}) bool { + gs.modelConfigs.Range(func(key, value any) bool { mc, ok := value.(*configstoreTables.TableModelConfig) if !ok || mc == nil { return true // continue @@ -1527,10 +1539,34 @@ func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) var shouldReset bool var newLastReset time.Time - if budget.CalendarAligned { + // Check if the owning VK has calendar alignment enabled + // virtualKeys map is keyed by VK value (not ID), so we scan to find by VirtualKeyID + calendarAligned := false + if budget.VirtualKeyID != nil { + gs.virtualKeys.Range(func(_, v interface{}) bool { + if vk, ok := v.(*configstoreTables.TableVirtualKey); ok && vk != nil && vk.ID == *budget.VirtualKeyID { + calendarAligned = vk.CalendarAligned + return false // stop + } + return true + }) + } else if budget.ProviderConfigID != nil { + // Provider config budgets: look up the VK that owns this provider config + gs.virtualKeys.Range(func(_, v interface{}) bool { + if vk, ok := v.(*configstoreTables.TableVirtualKey); ok && vk != nil { + for _, pc := range vk.ProviderConfigs { + if pc.ID == *budget.ProviderConfigID { + calendarAligned = vk.CalendarAligned + return false // stop + } + } + } + return true + }) + } + + if calendarAligned { // Calendar-aligned: reset when we've entered a genuinely new calendar period. - // This avoids the double-reset bug with rolling durations in months with - // more days than ParseDuration approximates (e.g. 31-day months with "1M" = 30 days). currentPeriodStart := configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, now) if currentPeriodStart.After(budget.LastReset) { shouldReset = true @@ -2102,33 +2138,17 @@ func (gs *LocalGovernanceStore) loadFromConfigMemory(ctx context.Context, config } } - for i := range budgets { - if vk.BudgetID != nil && budgets[i].ID == *vk.BudgetID { - vk.Budget = &budgets[i] - } - } - for i := range rateLimits { if vk.RateLimitID != nil && rateLimits[i].ID == *vk.RateLimitID { vk.RateLimit = &rateLimits[i] } } - // Populate provider config relationships with budgets and rate limits + // Populate provider config relationships with rate limits if vk.ProviderConfigs != nil { for j := range vk.ProviderConfigs { pc := &vk.ProviderConfigs[j] - // Populate budget - if pc.BudgetID != nil { - for k := range budgets { - if budgets[k].ID == *pc.BudgetID { - pc.Budget = &budgets[k] - break - } - } - } - // Populate rate limit if pc.RateLimitID != nil { for k := range rateLimits { @@ -2250,7 +2270,7 @@ func (gs *LocalGovernanceStore) rebuildInMemoryStructures(ctx context.Context, c if rules, ok := value.([]*configstoreTables.TableRoutingRule); ok { for _, rule := range rules { if _, err := gs.GetRoutingProgram(rule); err != nil { - gs.logger.Warn("Failed to pre-compile routing program for rule %s: %v", rule.ID, err) + gs.logger.Warn("Failed to pre-compile routing program for rule %s: %v", rule.Name, err) } } } @@ -2373,22 +2393,36 @@ func (gs *LocalGovernanceStore) collectBudgetsFromHierarchy(vk *configstoreTable var budgetNames []string // Collect all budgets in hierarchy order using lock-free sync.Map access (Provider Configs β†’ VK β†’ Team β†’ Customer) + seen := make(map[string]bool) for _, pc := range vk.ProviderConfigs { - if pc.BudgetID != nil && pc.Provider == string(requestedProvider) { - if budgetValue, exists := gs.budgets.Load(*pc.BudgetID); exists && budgetValue != nil { + if pc.Provider != string(requestedProvider) { + continue + } + // Multi-budgets + for _, b := range pc.Budgets { + if seen[b.ID] { + continue + } + if budgetValue, exists := gs.budgets.Load(b.ID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { budgets = append(budgets, budget) budgetNames = append(budgetNames, pc.Provider) + seen[budget.ID] = true } } } } - if vk.BudgetID != nil { - if budgetValue, exists := gs.budgets.Load(*vk.BudgetID); exists && budgetValue != nil { + // VK-level multi-budgets + for _, b := range vk.Budgets { + if seen[b.ID] { + continue + } + if budgetValue, exists := gs.budgets.Load(b.ID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { budgets = append(budgets, budget) budgetNames = append(budgetNames, "VK") + seen[budget.ID] = true } } } @@ -2477,9 +2511,9 @@ func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.T return // Nothing to create } - // Create associated budget if exists - if vk.Budget != nil { - gs.budgets.Store(vk.Budget.ID, vk.Budget) + // Store budgets + for i := range vk.Budgets { + gs.budgets.Store(vk.Budgets[i].ID, &vk.Budgets[i]) } // Create associated rate limit if exists @@ -2490,8 +2524,8 @@ func (gs *LocalGovernanceStore) CreateVirtualKeyInMemory(vk *configstoreTables.T // Create provider config budgets and rate limits if they exist if vk.ProviderConfigs != nil { for _, pc := range vk.ProviderConfigs { - if pc.Budget != nil { - gs.budgets.Store(pc.Budget.ID, pc.Budget) + for i := range pc.Budgets { + gs.budgets.Store(pc.Budgets[i].ID, &pc.Budgets[i]) } if pc.RateLimit != nil { gs.rateLimits.Store(pc.RateLimit.ID, pc.RateLimit) @@ -2518,23 +2552,37 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.T // Create clone to avoid modifying the original clone := *vk - // Update Budget for VK in memory store - if clone.Budget != nil { - // Preserve existing usage from memory when updating budget config - // The usage tracker maintains current usage in memory, and we only want to update - // the configuration fields (max_limit, reset_duration) from the database - if existingBudgetValue, exists := gs.budgets.Load(clone.Budget.ID); exists && existingBudgetValue != nil { + + // Collect all incoming budget IDs across VK + provider configs to avoid + // deleting a budget that was moved between VK-level and PC-level in one update. + allNewBudgetIDs := make(map[string]bool) + for i := range clone.Budgets { + allNewBudgetIDs[clone.Budgets[i].ID] = true + } + for i := range clone.ProviderConfigs { + for j := range clone.ProviderConfigs[i].Budgets { + allNewBudgetIDs[clone.ProviderConfigs[i].Budgets[j].ID] = true + } + } + + // Update multi-budgets for VK + for i := range clone.Budgets { + // Preserve existing usage from memory + if existingBudgetValue, exists := gs.budgets.Load(clone.Budgets[i].ID); exists && existingBudgetValue != nil { if existingBudget, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && existingBudget != nil { - // Preserve current usage and last reset time from existing in-memory budget - clone.Budget.CurrentUsage = existingBudget.CurrentUsage - clone.Budget.LastReset = existingBudget.LastReset + clone.Budgets[i].CurrentUsage = existingBudget.CurrentUsage + clone.Budgets[i].LastReset = existingBudget.LastReset } } - gs.budgets.Store(clone.Budget.ID, clone.Budget) - } else if existingVK.Budget != nil { - // Budget was removed from the virtual key, delete it from memory - gs.budgets.Delete(existingVK.Budget.ID) + gs.budgets.Store(clone.Budgets[i].ID, &clone.Budgets[i]) } + // Delete removed multi-budgets + for _, oldBudget := range existingVK.Budgets { + if !allNewBudgetIDs[oldBudget.ID] { + gs.budgets.Delete(oldBudget.ID) + } + } + if clone.RateLimit != nil { // Preserve existing usage from memory when updating rate limit config // The usage tracker maintains current usage in memory, and we only want to update @@ -2584,22 +2632,23 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyInMemory(vk *configstoreTables.T clone.ProviderConfigs[i].RateLimit = nil } } - // Update Budget for provider config in memory store - if pc.Budget != nil { - // Preserve existing usage from memory when updating provider config budget - if existingBudgetValue, exists := gs.budgets.Load(pc.Budget.ID); exists && existingBudgetValue != nil { + // Update multi-budgets for provider config + for j := range clone.ProviderConfigs[i].Budgets { + b := &clone.ProviderConfigs[i].Budgets[j] + if existingBudgetValue, exists := gs.budgets.Load(b.ID); exists && existingBudgetValue != nil { if existingBudget, ok := existingBudgetValue.(*configstoreTables.TableBudget); ok && existingBudget != nil { - // Preserve current usage and last reset time from existing in-memory budget - clone.ProviderConfigs[i].Budget.CurrentUsage = existingBudget.CurrentUsage - clone.ProviderConfigs[i].Budget.LastReset = existingBudget.LastReset + b.CurrentUsage = existingBudget.CurrentUsage + b.LastReset = existingBudget.LastReset } } - gs.budgets.Store(clone.ProviderConfigs[i].Budget.ID, clone.ProviderConfigs[i].Budget) - } else { - // Budget was removed from provider config, delete it from memory if it existed - if existingPC, exists := existingProviderConfigs[pc.ID]; exists && existingPC.Budget != nil { - gs.budgets.Delete(existingPC.Budget.ID) - clone.ProviderConfigs[i].Budget = nil + gs.budgets.Store(b.ID, b) + } + // Delete removed multi-budgets for this provider config + if existingPC, exists := existingProviderConfigs[pc.ID]; exists { + for _, oldBudget := range existingPC.Budgets { + if !allNewBudgetIDs[oldBudget.ID] { + gs.budgets.Delete(oldBudget.ID) + } } } } @@ -2625,9 +2674,9 @@ func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { } if vk.ID == vkID { - // Delete associated budget if exists - if vk.BudgetID != nil { - gs.budgets.Delete(*vk.BudgetID) + // Delete budgets + for _, b := range vk.Budgets { + gs.budgets.Delete(b.ID) } // Delete associated rate limit if exists @@ -2638,8 +2687,8 @@ func (gs *LocalGovernanceStore) DeleteVirtualKeyInMemory(vkID string) { // Delete provider config budgets and rate limits if vk.ProviderConfigs != nil { for _, pc := range vk.ProviderConfigs { - if pc.BudgetID != nil { - gs.budgets.Delete(*pc.BudgetID) + for _, b := range pc.Budgets { + gs.budgets.Delete(b.ID) } if pc.RateLimitID != nil { gs.rateLimits.Delete(*pc.RateLimitID) @@ -3165,18 +3214,22 @@ func (gs *LocalGovernanceStore) updateBudgetReferences(resetBudget *configstoreT needsUpdate := false clone := *vk - // Check VK-level budget - if vk.BudgetID != nil && *vk.BudgetID == budgetID { - clone.Budget = resetBudget - needsUpdate = true + // Check VK-level budgets + for i, b := range clone.Budgets { + if b.ID == budgetID { + clone.Budgets[i] = *resetBudget + needsUpdate = true + } } // Check provider config budgets if vk.ProviderConfigs != nil { - for i, pc := range clone.ProviderConfigs { - if pc.BudgetID != nil && *pc.BudgetID == budgetID { - clone.ProviderConfigs[i].Budget = resetBudget - needsUpdate = true + for i := range clone.ProviderConfigs { + for j, b := range clone.ProviderConfigs[i].Budgets { + if b.ID == budgetID { + clone.ProviderConfigs[i].Budgets[j] = *resetBudget + needsUpdate = true + } } } } @@ -3638,9 +3691,9 @@ func (gs *LocalGovernanceStore) GetBudgetAndRateLimitStatus(ctx context.Context, } } } - // Get budget status - if pc.BudgetID != nil { - if budgetValue, ok := gs.budgets.Load(*pc.BudgetID); ok && budgetValue != nil { + // Get budget status from multi-budgets + for _, b := range pc.Budgets { + if budgetValue, ok := gs.budgets.Load(b.ID); ok && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { baseline, exists := budgetBaselines[budget.ID] if !exists { @@ -3731,7 +3784,7 @@ func (gs *LocalGovernanceStore) UpdateRoutingRuleInMemory(rule *configstoreTable // Recompile the program immediately to update cache with fresh compilation if _, err := gs.GetRoutingProgram(rule); err != nil { - gs.logger.Warn("Failed to recompile routing program for rule %s: %v", rule.ID, err) + gs.logger.Warn("Failed to recompile routing program for rule %s: %v", rule.Name, err) } return nil diff --git a/plugins/governance/store_test.go b/plugins/governance/store_test.go index 0640052d56..21a54e1bd5 100644 --- a/plugins/governance/store_test.go +++ b/plugins/governance/store_test.go @@ -198,16 +198,397 @@ func TestGovernanceStore_CheckBudget_HierarchyValidation(t *testing.T) { // Test: If VK budget exceeds limit, should fail // Update the budget directly in the budgets map (since UpdateVirtualKeyInMemory preserves usage) - if vk.BudgetID != nil { - if budgetValue, exists := store.budgets.Load(*vk.BudgetID); exists && budgetValue != nil { + if len(vk.Budgets) > 0 { + budgetID := vk.Budgets[0].ID + if budgetValue, exists := store.budgets.Load(budgetID); exists && budgetValue != nil { if budget, ok := budgetValue.(*configstoreTables.TableBudget); ok && budget != nil { budget.CurrentUsage = 100.0 - store.budgets.Store(*vk.BudgetID, budget) + store.budgets.Store(budgetID, budget) } } } err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) - assert.Error(t, err, "Should fail when VK budget exceeds limit") + require.Error(t, err, "Should fail when VK budget exceeds limit") +} + +// TestGovernanceStore_MultiBudget_AllUnderLimit tests that requests pass when all budgets are under their limits +func TestGovernanceStore_MultiBudget_AllUnderLimit(t *testing.T) { + logger := NewMockLogger() + + // Create VK with hourly ($10) and daily ($100) budgets + hourlyBudget := buildBudgetWithUsage("hourly", 10.0, 5.0, "1h") + dailyBudget := buildBudgetWithUsage("daily", 100.0, 40.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + // Add provider config so the resolver allows the provider + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.NoError(t, err, "Should pass when all budgets are under limit") +} + +// TestGovernanceStore_MultiBudget_SmallBudgetExceeded tests that request is blocked when the smaller budget exceeds its limit +func TestGovernanceStore_MultiBudget_SmallBudgetExceeded(t *testing.T) { + logger := NewMockLogger() + + // Hourly at limit, daily still has room + hourlyBudget := buildBudgetWithUsage("hourly", 10.0, 10.0, "1h") + dailyBudget := buildBudgetWithUsage("daily", 100.0, 40.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + require.Error(t, err, "Should fail when hourly budget is exceeded even though daily is fine") + assert.Contains(t, err.Error(), "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_LargeBudgetExceeded tests that request is blocked when only the larger budget exceeds +func TestGovernanceStore_MultiBudget_LargeBudgetExceeded(t *testing.T) { + logger := NewMockLogger() + + // Hourly has room, but daily is at limit + hourlyBudget := buildBudgetWithUsage("hourly", 10.0, 3.0, "1h") + dailyBudget := buildBudgetWithUsage("daily", 100.0, 100.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + require.Error(t, err, "Should fail when daily budget is exceeded even though hourly is fine") + assert.Contains(t, err.Error(), "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_UsageUpdatesAllBudgets tests that usage updates are applied to every budget in the hierarchy +func TestGovernanceStore_MultiBudget_UsageUpdatesAllBudgets(t *testing.T) { + logger := NewMockLogger() + + hourlyBudget := buildBudget("hourly", 10.0, "1h") + dailyBudget := buildBudget("daily", 100.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + + // Simulate a $3.50 request + err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 3.50) + require.NoError(t, err) + + // Both budgets should reflect the cost + hourlyVal, exists := store.budgets.Load("hourly") + require.True(t, exists) + assert.InDelta(t, 3.50, hourlyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, "Hourly budget should reflect usage") + + dailyVal, exists := store.budgets.Load("daily") + require.True(t, exists) + assert.InDelta(t, 3.50, dailyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, "Daily budget should reflect usage") + + // Second request: $7.00 β€” should push hourly over limit + err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 7.00) + require.NoError(t, err) + + hourlyVal, _ = store.budgets.Load("hourly") + assert.InDelta(t, 10.50, hourlyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, "Hourly budget should accumulate") + + dailyVal, _ = store.budgets.Load("daily") + assert.InDelta(t, 10.50, dailyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, "Daily budget should accumulate") + + // Now CheckBudget should fail (hourly exceeded) + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + require.Error(t, err, "Should fail after usage exceeds hourly budget") + assert.Contains(t, err.Error(), "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_ProviderConfigBudgets tests that provider-config-level multi-budgets are enforced +func TestGovernanceStore_MultiBudget_ProviderConfigBudgets(t *testing.T) { + logger := NewMockLogger() + + // Provider-level budgets: hourly $5 (exceeded), daily $50 (ok) + pcHourly := buildBudgetWithUsage("pc-hourly", 5.0, 5.0, "1h") + pcDaily := buildBudgetWithUsage("pc-daily", 50.0, 10.0, "1d") + + pc := buildProviderConfigWithBudgets("openai", []string{"*"}, + []configstoreTables.TableBudget{*pcHourly, *pcDaily}) + + vk := buildVirtualKeyWithProviders("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableVirtualKeyProviderConfig{pc}) + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*pcHourly, *pcDaily}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + require.Error(t, err, "Should fail when provider config hourly budget is exceeded") + assert.Contains(t, err.Error(), "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_VKAndProviderConfigCombined tests budgets at both VK and provider config levels +func TestGovernanceStore_MultiBudget_VKAndProviderConfigCombined(t *testing.T) { + logger := NewMockLogger() + + // VK-level budgets: all under limit + vkMonthly := buildBudgetWithUsage("vk-monthly", 1000.0, 200.0, "1M") + + // Provider-config-level budgets: hourly at limit + pcHourly := buildBudgetWithUsage("pc-hourly", 5.0, 5.0, "1h") + + pc := buildProviderConfigWithBudgets("openai", []string{"*"}, + []configstoreTables.TableBudget{*pcHourly}) + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*vkMonthly}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{pc} + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*vkMonthly, *pcHourly}, + }, nil) + require.NoError(t, err) + + vk, _ = store.GetVirtualKey("sk-bf-test") + + // Provider config budget exceeded β†’ should block even though VK budget is fine + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + require.Error(t, err, "Should fail: provider config budget exceeded even though VK budget is fine") + assert.Contains(t, err.Error(), "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_ResolverBlocksOnBudgetExceeded tests that the full resolver flow blocks when any budget is exceeded +func TestGovernanceStore_MultiBudget_ResolverBlocksOnBudgetExceeded(t *testing.T) { + logger := NewMockLogger() + + // Two VK-level budgets: hourly at limit, daily has room + hourlyBudget := buildBudgetWithUsage("hourly", 10.0, 10.0, "1h") + dailyBudget := buildBudgetWithUsage("daily", 100.0, 30.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, nil, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) + assertDecision(t, DecisionBudgetExceeded, result) + assert.Contains(t, result.Reason, "budget exceeded") +} + +// TestGovernanceStore_MultiBudget_ResolverAllowsUnderLimit tests that the full resolver flow allows requests when all budgets are under limit +func TestGovernanceStore_MultiBudget_ResolverAllowsUnderLimit(t *testing.T) { + logger := NewMockLogger() + + hourlyBudget := buildBudgetWithUsage("hourly", 10.0, 5.0, "1h") + dailyBudget := buildBudgetWithUsage("daily", 100.0, 30.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, nil, logger) + ctx := &schemas.BifrostContext{} + + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) + assertDecision(t, DecisionAllow, result) +} + +// TestGovernanceStore_MultiBudget_UsageDrivesBlockAfterRequests tests the full lifecycle: +// start under limit β†’ accumulate usage β†’ eventually hit a budget β†’ get blocked +func TestGovernanceStore_MultiBudget_UsageDrivesBlockAfterRequests(t *testing.T) { + logger := NewMockLogger() + + // Tight hourly ($2), generous daily ($100) + hourlyBudget := buildBudget("hourly", 2.0, "1h") + dailyBudget := buildBudget("daily", 100.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*hourlyBudget, *dailyBudget}, + }, nil) + require.NoError(t, err) + + resolver := NewBudgetResolver(store, nil, logger) + + // Request 1: $0.80 β€” both budgets fine + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) + require.NoError(t, err) + + ctx := &schemas.BifrostContext{} + result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) + assertDecision(t, DecisionAllow, result) + + // Request 2: $0.80 β€” still fine ($1.60 total) + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) + require.NoError(t, err) + + ctx = &schemas.BifrostContext{} + result = resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) + assertDecision(t, DecisionAllow, result) + + // Request 3: $0.80 β€” pushes hourly to $2.40 > $2.00 limit β†’ blocked + vk, _ = store.GetVirtualKey("sk-bf-test") + err = store.UpdateVirtualKeyBudgetUsageInMemory(context.Background(), vk, schemas.OpenAI, 0.80) + require.NoError(t, err) + + ctx = &schemas.BifrostContext{} + result = resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest, false) + assertDecision(t, DecisionBudgetExceeded, result) + assert.Contains(t, result.Reason, "budget exceeded") + + // Verify daily budget is still under limit + dailyVal, exists := store.budgets.Load("daily") + require.True(t, exists) + assert.InDelta(t, 2.40, dailyVal.(*configstoreTables.TableBudget).CurrentUsage, 0.01, + "Daily budget should be at $2.40, well under $100 limit") +} + +// TestGovernanceStore_MultiBudget_CalendarAligned tests that calendar-aligned budgets are stored and retrievable +func TestGovernanceStore_MultiBudget_CalendarAligned(t *testing.T) { + logger := NewMockLogger() + + // Calendar alignment is a VK-level setting β€” budgets don't have it + dailyBudget := &configstoreTables.TableBudget{ + ID: "daily-cal", + MaxLimit: 50.0, + CurrentUsage: 10.0, + ResetDuration: "1d", + LastReset: time.Now(), + } + monthlyBudget := &configstoreTables.TableBudget{ + ID: "monthly-cal", + MaxLimit: 1000.0, + CurrentUsage: 200.0, + ResetDuration: "1M", + LastReset: time.Now(), + } + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*dailyBudget, *monthlyBudget}) + vk.CalendarAligned = true // VK-level setting applies to all budgets + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{ + VirtualKeys: []configstoreTables.TableVirtualKey{*vk}, + Budgets: []configstoreTables.TableBudget{*dailyBudget, *monthlyBudget}, + }, nil) + require.NoError(t, err) + + // Verify VK-level calendar_aligned is set + vk, _ = store.GetVirtualKey("sk-bf-test") + assert.True(t, vk.CalendarAligned, "VK should have calendar_aligned=true") + + // Both under limit β€” should pass + err = store.CheckBudget(context.Background(), vk, &EvaluationRequest{Provider: schemas.OpenAI}, nil) + assert.NoError(t, err) +} + +// TestGovernanceStore_MultiBudget_InMemoryCreateAndDelete tests CreateVirtualKeyInMemory and DeleteVirtualKeyInMemory +// properly store and clean up multi-budget entries +func TestGovernanceStore_MultiBudget_InMemoryCreateAndDelete(t *testing.T) { + logger := NewMockLogger() + + store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) + require.NoError(t, err) + + b1 := buildBudget("b1", 10.0, "1h") + b2 := buildBudget("b2", 100.0, "1d") + + vk := buildVirtualKeyWithMultiBudgets("vk1", "sk-bf-test", "Test VK", + []configstoreTables.TableBudget{*b1, *b2}) + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } + + // Create + store.CreateVirtualKeyInMemory(vk) + + _, exists := store.budgets.Load("b1") + assert.True(t, exists, "Budget b1 should be in memory after create") + _, exists = store.budgets.Load("b2") + assert.True(t, exists, "Budget b2 should be in memory after create") + + retrieved, found := store.GetVirtualKey("sk-bf-test") + require.True(t, found) + assert.Len(t, retrieved.Budgets, 2, "VK should have 2 budgets") + + // Delete + store.DeleteVirtualKeyInMemory("vk1") + + _, exists = store.budgets.Load("b1") + assert.False(t, exists, "Budget b1 should be removed after delete") + _, exists = store.budgets.Load("b2") + assert.False(t, exists, "Budget b2 should be removed after delete") + + _, found = store.GetVirtualKey("sk-bf-test") + assert.False(t, found, "VK should not be found after delete") } // TestGovernanceStore_UpdateRateLimitUsage_TokensAndRequests tests atomic rate limit usage updates @@ -319,9 +700,9 @@ func TestGovernanceStore_ResetExpiredBudgets(t *testing.T) { // Retrieve the updated VK to check budget changes updatedVK, _ := store.GetVirtualKey("sk-bf-test") require.NotNil(t, updatedVK) - require.NotNil(t, updatedVK.Budget) + require.True(t, len(updatedVK.Budgets) > 0, "VK should have budgets") - assert.Equal(t, 0.0, updatedVK.Budget.CurrentUsage, "Budget usage should be reset") + assert.Equal(t, 0.0, updatedVK.Budgets[0].CurrentUsage, "Budget usage should be reset") } // TestGovernanceStore_GetAllBudgets tests retrieving all budgets diff --git a/plugins/governance/test_utils.go b/plugins/governance/test_utils.go index a6da347cbf..b7ad6dbad3 100644 --- a/plugins/governance/test_utils.go +++ b/plugins/governance/test_utils.go @@ -15,7 +15,6 @@ import ( configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // MockLogger implements schemas.Logger for testing @@ -89,9 +88,13 @@ func buildVirtualKey(id, value, name string, isActive bool) *configstoreTables.T func buildVirtualKeyWithBudget(id, value, name string, budget *configstoreTables.TableBudget) *configstoreTables.TableVirtualKey { vk := buildVirtualKey(id, value, name, true) - vk.Budget = budget - budgetID := budget.ID - vk.BudgetID = &budgetID + vkID := id + budget.VirtualKeyID = &vkID + vk.Budgets = []configstoreTables.TableBudget{*budget} + // Add a default provider config so the resolver doesn't block at provider check + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } return vk } @@ -100,6 +103,10 @@ func buildVirtualKeyWithRateLimit(id, value, name string, rateLimit *configstore vk.RateLimit = rateLimit rateLimitID := rateLimit.ID vk.RateLimitID = &rateLimitID + // Add a default provider config so the resolver doesn't block at provider check + vk.ProviderConfigs = []configstoreTables.TableVirtualKeyProviderConfig{ + buildProviderConfig("openai", []string{"*"}), + } return vk } @@ -189,11 +196,26 @@ func buildProviderConfig(provider string, allowedModels []string) configstoreTab AllowedModels: allowedModels, Weight: bifrost.Ptr(1.0), RateLimit: nil, - Budget: nil, Keys: []configstoreTables.TableKey{}, } } +func buildProviderConfigWithBudgets(provider string, allowedModels []string, budgets []configstoreTables.TableBudget) configstoreTables.TableVirtualKeyProviderConfig { + pc := buildProviderConfig(provider, allowedModels) + pc.Budgets = budgets + return pc +} + +func buildVirtualKeyWithMultiBudgets(id, value, name string, budgets []configstoreTables.TableBudget) *configstoreTables.TableVirtualKey { + vk := buildVirtualKey(id, value, name, true) + for i := range budgets { + vkID := id + budgets[i].VirtualKeyID = &vkID + } + vk.Budgets = budgets + return vk +} + func buildProviderConfigWithRateLimit(provider string, allowedModels []string, rateLimit *configstoreTables.TableRateLimit) configstoreTables.TableVirtualKeyProviderConfig { pc := buildProviderConfig(provider, allowedModels) pc.RateLimit = rateLimit @@ -221,15 +243,6 @@ func assertRateLimitInfo(t *testing.T, result *EvaluationResult) { assert.NotNil(t, result.RateLimitInfo, "RateLimitInfo should be present in result") } -func requireNoError(t *testing.T, err error, msg string) { - t.Helper() - require.NoError(t, err, msg) -} - -func requireError(t *testing.T, err error, msg string) { - t.Helper() - require.Error(t, err, msg) -} func buildModelConfig(id, modelName string, provider *string, budget *configstoreTables.TableBudget, rateLimit *configstoreTables.TableRateLimit) *configstoreTables.TableModelConfig { mc := &configstoreTables.TableModelConfig{ diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index f33abb3aa1..b15b6fb02e 100644 --- a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -34,7 +34,7 @@ func ParseVirtualKeyFromFastHTTPRequest(req *fasthttp.RequestCtx) *string { return bifrost.Ptr(xAPIKey) } xGoogleAPIKey := string(req.Request.Header.Peek("x-goog-api-key")) - if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { + if xGoogleAPIKey != "" && strings.HasPrefix(strings.ToLower(xGoogleAPIKey), VirtualKeyPrefix) { return bifrost.Ptr(xGoogleAPIKey) } return nil @@ -99,9 +99,9 @@ func (p *GovernancePlugin) filterModelsForVirtualKey( return []schemas.Model{} // VK not found, return empty list } - // Empty ProviderConfigs means all models are allowed + // Empty ProviderConfigs means no models are allowed (deny-by-default) if len(vk.ProviderConfigs) == 0 { - return models + return []schemas.Model{} } // Filter models based on ProviderConfigs diff --git a/plugins/governance/version b/plugins/governance/version index 62f0c2cadb..8e03717dca 100644 --- a/plugins/governance/version +++ b/plugins/governance/version @@ -1 +1 @@ -1.4.36 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/plugins/jsonparser/changelog.md b/plugins/jsonparser/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/jsonparser/changelog.md +++ b/plugins/jsonparser/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/jsonparser/go.mod b/plugins/jsonparser/go.mod index a2c9947112..b8ba648784 100644 --- a/plugins/jsonparser/go.mod +++ b/plugins/jsonparser/go.mod @@ -2,7 +2,7 @@ module github.com/maximhq/bifrost/plugins/jsonparser go 1.26.1 -require github.com/maximhq/bifrost/core v1.4.17 +require github.com/maximhq/bifrost/core v1.5.1 require ( cloud.google.com/go v0.123.0 // indirect diff --git a/plugins/jsonparser/go.sum b/plugins/jsonparser/go.sum index 98dae165ac..b96c86a77e 100644 --- a/plugins/jsonparser/go.sum +++ b/plugins/jsonparser/go.sum @@ -109,8 +109,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/plugins/jsonparser/version b/plugins/jsonparser/version index de17646bc0..8e03717dca 100644 --- a/plugins/jsonparser/version +++ b/plugins/jsonparser/version @@ -1 +1 @@ -1.4.35 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/plugins/litellmcompat/changelog.md b/plugins/litellmcompat/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/litellmcompat/changelog.md +++ b/plugins/litellmcompat/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/litellmcompat/go.mod b/plugins/litellmcompat/go.mod index c15599320d..c873e5bba1 100644 --- a/plugins/litellmcompat/go.mod +++ b/plugins/litellmcompat/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/litellmcompat go 1.26.1 require ( - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 ) require ( diff --git a/plugins/litellmcompat/go.sum b/plugins/litellmcompat/go.sum index f993012bf1..6f1eeb5001 100644 --- a/plugins/litellmcompat/go.sum +++ b/plugins/litellmcompat/go.sum @@ -193,10 +193,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/plugins/litellmcompat/texttochat.go b/plugins/litellmcompat/texttochat.go index 2eb7446348..b0c1b0a309 100644 --- a/plugins/litellmcompat/texttochat.go +++ b/plugins/litellmcompat/texttochat.go @@ -83,7 +83,7 @@ func transformTextToChatResponse(_ *schemas.BifrostContext, resp *schemas.Bifros // Restore original request type metadata textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.ModelRequested = tc.OriginalModel + textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel textCompletionResponse.ExtraFields.LiteLLMCompat = true if logger != nil { @@ -110,7 +110,7 @@ func transformTextToChatError(_ *schemas.BifrostContext, err *schemas.BifrostErr // Restore original request type in error metadata err.ExtraFields.RequestType = tc.OriginalRequestType - err.ExtraFields.ModelRequested = tc.OriginalModel + err.ExtraFields.OriginalModelRequested = tc.OriginalModel err.ExtraFields.LiteLLMCompat = true return err @@ -141,7 +141,7 @@ func TransformTextToChatStreamResponse(ctx *schemas.BifrostContext, stream *sche // Restore original request type metadata textCompletionResponse.ExtraFields.RequestType = tc.OriginalRequestType - textCompletionResponse.ExtraFields.ModelRequested = tc.OriginalModel + textCompletionResponse.ExtraFields.OriginalModelRequested = tc.OriginalModel textCompletionResponse.ExtraFields.LiteLLMCompat = true // Return a new stream with the text completion response diff --git a/plugins/litellmcompat/version b/plugins/litellmcompat/version index d34586a15a..6da28dde76 100644 --- a/plugins/litellmcompat/version +++ b/plugins/litellmcompat/version @@ -1 +1 @@ -0.0.25 \ No newline at end of file +0.1.1 \ No newline at end of file diff --git a/plugins/logging/changelog.md b/plugins/logging/changelog.md index e69de29bb2..c7ab9d714c 100644 --- a/plugins/logging/changelog.md +++ b/plugins/logging/changelog.md @@ -0,0 +1,4 @@ +- feat: add realtime turn logging +- feat: add support for tracking userId, teamId, customerId, and businessUnitId +- feat: allow path whitelisting from security config +- fix: MCP tool logs not being captured correctly diff --git a/plugins/logging/go.mod b/plugins/logging/go.mod index 204798b229..62da1daa80 100644 --- a/plugins/logging/go.mod +++ b/plugins/logging/go.mod @@ -4,8 +4,8 @@ go 1.26.1 require ( github.com/bytedance/sonic v1.15.0 - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 ) require ( diff --git a/plugins/logging/go.sum b/plugins/logging/go.sum index f993012bf1..6f1eeb5001 100644 --- a/plugins/logging/go.sum +++ b/plugins/logging/go.sum @@ -193,10 +193,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 749aa30599..c8f5a14399 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -6,6 +6,7 @@ package logging import ( "context" "fmt" + "math" "strings" "sync" "sync/atomic" @@ -99,6 +100,7 @@ func applyLargePayloadPreviewsToEntry(ctx *schemas.BifrostContext, entry *logsto } } +// scheduleDeferredUsageUpdate schedules a deferred usage update for the request. func (p *LoggerPlugin) scheduleDeferredUsageUpdate(ctx *schemas.BifrostContext, requestID string, usageAlreadyPresent bool) { if usageAlreadyPresent || ctx == nil { return @@ -108,7 +110,6 @@ func (p *LoggerPlugin) scheduleDeferredUsageUpdate(ctx *schemas.BifrostContext, if !ok || deferredChan == nil { return } - p.wg.Add(1) go func() { defer p.wg.Done() @@ -127,7 +128,6 @@ func (p *LoggerPlugin) scheduleDeferredUsageUpdate(ctx *schemas.BifrostContext, p.logger.Warn("deferred usage update dropped for request %s: semaphore full", requestID) return } - usageUpdates := map[string]interface{}{ "prompt_tokens": deferredUsage.PromptTokens, "completion_tokens": deferredUsage.CompletionTokens, @@ -138,6 +138,27 @@ func (p *LoggerPlugin) scheduleDeferredUsageUpdate(ctx *schemas.BifrostContext, usageUpdates["token_usage"] = tempEntry.TokenUsage usageUpdates["cached_read_tokens"] = tempEntry.CachedReadTokens } + + // Check if log entry present in the store + // exponential backoff with jitter and 3 retries + // then fail + var found bool + var findErr error + for i := 0; i < 3; i++ { + found, findErr = p.store.IsLogEntryPresent(p.ctx, requestID) + if findErr != nil { + p.logger.Warn("failed to check if log entry is present for request %s: %v", requestID, findErr) + continue + } + if found { + break + } + time.Sleep(time.Duration(math.Pow(2, float64(i))) * time.Second * 2) + } + if !found { + p.logger.Warn("log entry not found for request %s after 3 retries. failed to update deferred usage for large payload request", requestID) + return + } if updErr := p.store.Update(p.ctx, requestID, usageUpdates); updErr != nil { p.logger.Warn("failed to update deferred usage for request %s: %v", requestID, updErr) } @@ -183,14 +204,16 @@ type InitialLogData struct { Object string InputHistory []schemas.ChatMessage ResponsesInputHistory []schemas.ResponsesMessage - Params interface{} + Params any SpeechInput *schemas.SpeechInput TranscriptionInput *schemas.TranscriptionInput ImageGenerationInput *schemas.ImageGenerationInput + ImageEditInput *schemas.ImageEditInput + ImageVariationInput *schemas.ImageVariationInput VideoGenerationInput *schemas.VideoGenerationInput Tools []schemas.ChatTool RoutingEngineUsed []string - Metadata map[string]interface{} + Metadata map[string]any PassthroughRequestBody string // Raw body for passthrough requests (UTF-8) } @@ -224,7 +247,8 @@ type LoggerPlugin struct { cleanupTicker *time.Ticker // Ticker for cleaning up old processing logs logMsgPool sync.Pool // Pool for reusing LogMessage structs updateDataPool sync.Pool // Pool for reusing UpdateLogData structs - pendingLogs sync.Map // Maps requestID -> *PendingLogData (PreLLMHook input data awaiting PostLLMHook) + pendingLogsEntries sync.Map // Maps requestID -> *PendingLogData (PreLLMHook input data awaiting PostLLMHook) + pendingLogsToInject sync.Map // Maps traceID -> *pendingInjectEntries (log entries to inject, supports multiple per trace) writeQueue chan *writeQueueEntry // Buffered channel for batch write queue closed atomic.Bool // Set during cleanup to prevent sends on closed writeQueue deferredUsageSem chan struct{} // Limits concurrent deferred usage DB updates @@ -257,12 +281,12 @@ func Init(ctx context.Context, config *Config, logger schemas.Logger, logsStore writeQueue: make(chan *writeQueueEntry, writeQueueCapacity), deferredUsageSem: make(chan struct{}, maxDeferredUsageConcurrency), logMsgPool: sync.Pool{ - New: func() interface{} { + New: func() any { return &LogMessage{} }, }, updateDataPool: sync.Pool{ - New: func() interface{} { + New: func() any { return &UpdateLogData{} }, }, @@ -425,6 +449,9 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr Model: model, Object: string(req.RequestType), } + if req.RequestType == schemas.RealtimeRequest { + initialData.Object = "realtime.turn" + } if p.disableContentLogging == nil || !*p.disableContentLogging { inputHistory, responsesInputHistory := p.extractInputHistory(req) @@ -445,6 +472,10 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr tools = append(tools, *tool.ToChatTool()) } initialData.Tools = tools + case schemas.RealtimeRequest: + if req.ResponsesRequest != nil { + initialData.Params = req.ResponsesRequest.Params + } case schemas.EmbeddingRequest: initialData.Params = req.EmbeddingRequest.Params case schemas.RerankRequest: @@ -470,6 +501,50 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest: initialData.Params = req.ImageGenerationRequest.Params initialData.ImageGenerationInput = req.ImageGenerationRequest.Input + case schemas.ImageEditRequest, schemas.ImageEditStreamRequest: + params := req.ImageEditRequest.Params + input := req.ImageEditRequest.Input + if input != nil { + reqThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadRequestThreshold).(int64) + if reqThreshold > 0 { + var totalSize int64 + for _, img := range input.Images { + totalSize += int64(len(img.Image)) + } + if totalSize > reqThreshold { + logInput := *input + logInput.Images = nil + initialData.ImageEditInput = &logInput + } else { + initialData.ImageEditInput = input + } + if params != nil && int64(len(params.Mask)) > reqThreshold { + logParams := *params + logParams.Mask = nil + initialData.Params = &logParams + } else { + initialData.Params = params + } + } else { + initialData.ImageEditInput = input + initialData.Params = params + } + } else { + initialData.Params = params + } + case schemas.ImageVariationRequest: + initialData.Params = req.ImageVariationRequest.Params + input := req.ImageVariationRequest.Input + if input != nil { + reqThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadRequestThreshold).(int64) + if reqThreshold > 0 && int64(len(input.Image.Image)) > reqThreshold { + logInput := *input + logInput.Image = schemas.ImageInput{} + initialData.ImageVariationInput = &logInput + } else { + initialData.ImageVariationInput = input + } + } case schemas.VideoGenerationRequest: initialData.Params = req.VideoGenerationRequest.Params initialData.VideoGenerationInput = req.VideoGenerationRequest.Input @@ -506,7 +581,7 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr } // Capture configured logging headers and x-bf-lh-* headers into metadata first - initialData.Metadata = p.captureLoggingHeaders(ctx) + initialData.Metadata = mergeRealtimeMetadata(p.captureLoggingHeaders(ctx), ctx) // System entries are set after so they take precedence over dynamic header values if isAsync, ok := ctx.Value(schemas.BifrostIsAsyncRequest).(bool); ok && isAsync { @@ -524,10 +599,15 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr // Determine effective request ID (fallback override) effectiveRequestID := requestID var parentRequestID string + if directParentRequestID, ok := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && directParentRequestID != "" { + parentRequestID = directParentRequestID + } fallbackRequestID, ok := ctx.Value(schemas.BifrostContextKeyFallbackRequestID).(string) if ok && fallbackRequestID != "" { effectiveRequestID = fallbackRequestID - parentRequestID = requestID + if parentRequestID == "" { + parentRequestID = requestID + } } fallbackIndex := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyFallbackIndex) @@ -552,7 +632,7 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr CreatedAt: time.Now(), Status: "processing", } - p.pendingLogs.Store(effectiveRequestID, pending) + p.pendingLogsEntries.Store(effectiveRequestID, pending) // Call callback synchronously for immediate UI feedback (WebSocket "processing" notification). // The entry does not exist in the DB yet - it will be written when PostLLMHook fires. p.mu.Lock() @@ -596,15 +676,22 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. virtualKeyName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceVirtualKeyName) routingRuleID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceRoutingRuleID) routingRuleName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceRoutingRuleName) + teamID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceTeamID) + teamName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceTeamName) + customerID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceCustomerID) + customerName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceCustomerName) + userID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceUserID) + businessUnitID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceBusinessUnitID) + businessUnitName := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyGovernanceBusinessUnitName) numberOfRetries := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyNumberOfRetries) - requestType, _, _ := bifrost.GetResponseFields(result, bifrostErr) + requestType, _, originalModelRequested, resolvedModelUsed := bifrost.GetResponseFields(result, bifrostErr) isFinalChunk := bifrost.IsFinalChunk(ctx) var tracer schemas.Tracer var traceID string - if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && requestType != schemas.RealtimeRequest { var err error tracer, traceID, err = bifrost.GetTracerFromContext(ctx) if err != nil { @@ -618,7 +705,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // and skip the write queue entirely. The accumulator work (ProcessStreamingChunk) // is fast (mutex + append). Only final chunks, errors, and non-streaming // responses need a DB write. - if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && !isFinalChunk && result != nil && bifrostErr == nil { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && requestType != schemas.RealtimeRequest && !isFinalChunk && result != nil && bifrostErr == nil { if tracer != nil && traceID != "" { tracer.ProcessStreamingChunk(traceID, false, result, bifrostErr) } @@ -628,7 +715,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. routingEngineLogs := formatRoutingEngineLogs(ctx.GetRoutingEngineLogs()) // Retrieve pending input data from PreLLMHook - pendingVal, hasPending := p.pendingLogs.LoadAndDelete(requestID) + pendingVal, hasPending := p.pendingLogsEntries.LoadAndDelete(requestID) if !hasPending { // If we have an error (e.g., cancellation/timeout), still write a minimal error entry // so the error is visible in logs. Without PreLLMHook's DB insert, silently returning @@ -638,18 +725,18 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. entry := &logstore.Log{ ID: requestID, Provider: string(bifrostErr.ExtraFields.Provider), - Model: bifrostErr.ExtraFields.ModelRequested, Status: "error", Stream: bifrost.IsStreamRequestType(requestType), Timestamp: time.Now().UTC(), CreatedAt: time.Now().UTC(), } + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) if data, err := sonic.Marshal(bifrostErr); err == nil { entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr applyLargePayloadPreviewsToEntry(ctx, entry) - p.enqueueLogEntry(entry, p.makePostWriteCallback(nil)) + p.storeOrEnqueueEntry(ctx, entry, p.makePostWriteCallback(nil)) } else { p.logger.Warn("no pending log data found for request %s, skipping log write", requestID) } @@ -657,6 +744,11 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } pending := pendingVal.(*PendingLogData) + if requestType == schemas.RealtimeRequest { + if resolvedRealtimeSessionID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRealtimeSessionID); resolvedRealtimeSessionID != "" { + pending.ParentRequestID = resolvedRealtimeSessionID + } + } // Build the complete log entry with input (from PreLLMHook) + output (from PostLLMHook) entry := buildCompleteLogEntryFromPending(pending) @@ -665,8 +757,9 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. if result != nil { latency = result.GetExtraFields().Latency } - applyOutputFieldsToEntry(entry, selectedKeyID, selectedKeyName, virtualKeyID, virtualKeyName, routingRuleID, routingRuleName, numberOfRetries, latency) + applyOutputFieldsToEntry(entry, selectedKeyID, selectedKeyName, virtualKeyID, virtualKeyName, routingRuleID, routingRuleName, teamID, teamName, customerID, customerName, userID, businessUnitID, businessUnitName, numberOfRetries, latency) entry.MetadataParsed = pending.InitialData.Metadata + entry.MetadataParsed = mergeRealtimeMetadata(entry.MetadataParsed, ctx) entry.RoutingEngineLogs = routingEngineLogs // Branch based on response type to populate output-specific fields @@ -674,6 +767,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // Path A: Error with nil result if result == nil && bifrostErr != nil { entry.Status = "error" + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) if bifrost.IsStreamRequestType(requestType) { entry.Stream = true } @@ -700,13 +794,13 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } } applyLargePayloadPreviewsToEntry(ctx, entry) - p.enqueueLogEntry(entry, p.makePostWriteCallback(nil)) + p.storeOrEnqueueEntry(ctx, entry, p.makePostWriteCallback(nil)) p.scheduleDeferredUsageUpdate(ctx, requestID, entry.TokenUsageParsed != nil) return result, bifrostErr, nil } // Path B: Streaming final chunk - if bifrost.IsStreamRequestType(requestType) { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.RealtimeRequest { var streamResponse *streaming.ProcessedStreamResponse if requestType != schemas.PassthroughStreamRequest && tracer != nil && traceID != "" { accResult := tracer.ProcessStreamingChunk(traceID, isFinalChunk, result, bifrostErr) @@ -718,6 +812,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. if bifrostErr != nil { entry.Status = "error" entry.Stream = true + applyModelAlias(entry, originalModelRequested, resolvedModelUsed) if data, err := sonic.Marshal(bifrostErr); err == nil { entry.ErrorDetails = string(data) } @@ -726,6 +821,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // tracer or traceID not available, or accumulator returned nil - still write what we have entry.Status = "success" entry.Stream = true + applyModelAlias(entry, originalModelRequested, resolvedModelUsed) } else if isFinalChunk { // Apply streaming output fields to the entry entry.Stream = true @@ -747,7 +843,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. tracer.CleanupStreamAccumulator(traceID) } - p.enqueueLogEntry(entry, p.makePostWriteCallback(nil)) + p.storeOrEnqueueEntry(ctx, entry, p.makePostWriteCallback(nil)) p.scheduleDeferredUsageUpdate(ctx, requestID, entry.TokenUsageParsed != nil) return result, bifrostErr, nil } @@ -755,6 +851,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // Path C: Non-streaming response if bifrostErr != nil { entry.Status = "error" + applyModelAlias(entry, bifrostErr.ExtraFields.OriginalModelRequested, bifrostErr.ExtraFields.ResolvedModelUsed) // Serialize error details immediately since bifrostErr may be released // back to the pool before the async batch writer processes this entry. // Also set ErrorDetailsParsed for UI callback (JSON serialization uses this field). @@ -762,9 +859,21 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr + // Realtime turns that fail mid-stream still need their input transcript + // surfaced β€” backfill from bifrostErr.ExtraFields.RawRequest if present. + if requestType == schemas.RealtimeRequest { + contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging + applyRealtimeRawRequestBackfill(entry, bifrostErr.ExtraFields.RawRequest, contentLoggingEnabled) + } } else if result != nil { entry.Status = "success" - p.applyNonStreamingOutputToEntry(entry, result) + extraFields := result.GetExtraFields() + applyModelAlias(entry, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed) + if requestType == schemas.RealtimeRequest { + p.applyRealtimeOutputToEntry(entry, result) + } else { + p.applyNonStreamingOutputToEntry(entry, result) + } // Flip status for passthrough error responses (4xx/5xx from provider) if isPassthroughErrorResponse(result) { entry.Status = "error" @@ -779,29 +888,30 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } entry.CacheDebugParsed = cacheDebug if p.pricingManager != nil { - if cost := p.pricingManager.CalculateCost(result); cost > 0 { + pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(entry.Provider)) + if cost := p.pricingManager.CalculateCost(result, pricingScopes); cost > 0 { entry.Cost = &cost } } - p.enqueueLogEntry(entry, p.makePostWriteCallback(func(updatedEntry *logstore.Log) { - updatedEntry.SelectedKey = &schemas.Key{ - ID: updatedEntry.SelectedKeyID, - Name: updatedEntry.SelectedKeyName, - } - if updatedEntry.VirtualKeyID != nil && updatedEntry.VirtualKeyName != nil { - updatedEntry.VirtualKey = &tables.TableVirtualKey{ - ID: *updatedEntry.VirtualKeyID, - Name: *updatedEntry.VirtualKeyName, - } + // Pre-apply denormalized fields for WebSocket callback enrichment + entry.SelectedKey = &schemas.Key{ + ID: entry.SelectedKeyID, + Name: entry.SelectedKeyName, + } + if entry.VirtualKeyID != nil && entry.VirtualKeyName != nil { + entry.VirtualKey = &tables.TableVirtualKey{ + ID: *entry.VirtualKeyID, + Name: *entry.VirtualKeyName, } - if updatedEntry.RoutingRuleID != nil && updatedEntry.RoutingRuleName != nil { - updatedEntry.RoutingRule = &tables.TableRoutingRule{ - ID: *updatedEntry.RoutingRuleID, - Name: *updatedEntry.RoutingRuleName, - } + } + if entry.RoutingRuleID != nil && entry.RoutingRuleName != nil { + entry.RoutingRule = &tables.TableRoutingRule{ + ID: *entry.RoutingRuleID, + Name: *entry.RoutingRuleName, } - })) + } + p.storeOrEnqueueEntry(ctx, entry, p.makePostWriteCallback(nil)) p.scheduleDeferredUsageUpdate(ctx, requestID, entry.TokenUsageParsed != nil) return result, bifrostErr, nil } @@ -828,6 +938,60 @@ func (p *LoggerPlugin) Cleanup() error { return nil } +// storeOrEnqueueEntry stores a log entry in pendingLogs keyed by traceID for later +// retrieval by Inject(), or enqueues directly if no traceID is available (Go SDK path). +// Multiple entries per traceID are supported (e.g. fallback/retry attempts within the same trace). +func (p *LoggerPlugin) storeOrEnqueueEntry(ctx *schemas.BifrostContext, entry *logstore.Log, callback func(entry *logstore.Log)) { + traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if traceID != "" { + // Append to slice for Inject() to pick up β€” supports multiple attempts per trace + existing, loaded := p.pendingLogsToInject.LoadOrStore(traceID, &pendingInjectEntries{entries: []*logstore.Log{entry}, createdAt: time.Now()}) + if !loaded { + return + } + pending := existing.(*pendingInjectEntries) + pending.mu.Lock() + pending.entries = append(pending.entries, entry) + pending.mu.Unlock() + } else { + // Fallback: no tracing (Go SDK path), enqueue directly + p.enqueueLogEntry(entry, callback) + } +} + +// Inject receives a completed trace and writes the log entries with plugin logs to DB. +// This implements the ObservabilityPlugin interface. +func (p *LoggerPlugin) Inject(_ context.Context, trace *schemas.Trace) error { + if trace == nil { + return nil + } + // Retrieve pending log entries built by PostLLMHook (stored by traceID) + entryVal, ok := p.pendingLogsToInject.LoadAndDelete(trace.TraceID) + if !ok { + return nil + } + pending, ok := entryVal.(*pendingInjectEntries) + if !ok { + return nil + } + + // Serialize plugin logs once for all entries + var pluginLogsJSON string + if len(trace.PluginLogs) > 0 { + grouped := schemas.GroupPluginLogsByName(trace.PluginLogs) + if data, err := sonic.Marshal(grouped); err == nil { + pluginLogsJSON = string(data) + } + } + + // Enqueue each log entry (supports multiple attempts per trace) + for _, entry := range pending.entries { + entry.PluginLogs = pluginLogsJSON + p.enqueueLogEntry(entry, p.makePostWriteCallback(nil)) + } + return nil +} + // MCP Plugin Interface Implementation // SetMCPToolLogCallback sets a callback function that will be called for each MCP tool log entry diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index f37b04c004..cf65beb9c1 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -4,14 +4,18 @@ package logging import ( "context" "fmt" + "strings" "time" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/framework/modelcatalog" "github.com/maximhq/bifrost/framework/streaming" ) +const realtimeMissingTranscriptText = "[Audio transcription unavailable]" + // insertInitialLogEntry creates a new log entry in the database using GORM func (p *LoggerPlugin) insertInitialLogEntry( ctx context.Context, @@ -40,6 +44,8 @@ func (p *LoggerPlugin) insertInitialLogEntry( SpeechInputParsed: data.SpeechInput, TranscriptionInputParsed: data.TranscriptionInput, ImageGenerationInputParsed: data.ImageGenerationInput, + ImageEditInputParsed: data.ImageEditInput, + ImageVariationInputParsed: data.ImageVariationInput, RoutingEnginesUsed: routingEnginesUsed, MetadataParsed: data.Metadata, VideoGenerationInputParsed: data.VideoGenerationInput, @@ -401,10 +407,13 @@ func (p *LoggerPlugin) updateStreamingLogEntry( tempEntry := &logstore.Log{} updates["latency"] = float64(streamResponse.Data.Latency) - // Update model if provided - if streamResponse.Data.Model != "" { - updates["model"] = streamResponse.Data.Model + // Update model and alias from resolved/requested model pair. + tempEntry2 := &logstore.Log{} + applyModelAlias(tempEntry2, streamResponse.RequestedModel, streamResponse.ResolvedModel) + if tempEntry2.Model != "" { + updates["model"] = tempEntry2.Model } + updates["alias"] = tempEntry2.Alias needsSerialization := false @@ -541,10 +550,8 @@ func (p *LoggerPlugin) applyStreamingOutputToEntry(entry *logstore.Log, streamRe latF := float64(streamResponse.Data.Latency) entry.Latency = &latF - // Update model if provided - if streamResponse.Data.Model != "" { - entry.Model = streamResponse.Data.Model - } + // Update model and alias from resolved/requested model pair. + applyModelAlias(entry, streamResponse.RequestedModel, streamResponse.ResolvedModel) // Token usage if streamResponse.Data.TokenUsage != nil { @@ -718,6 +725,350 @@ func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, resul } } +func (p *LoggerPlugin) applyRealtimeOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse) { + if result == nil || result.ResponsesResponse == nil { + return + } + + if usage := result.ResponsesResponse.Usage; usage != nil { + bifrostUsage := usage.ToBifrostLLMUsage() + entry.TokenUsageParsed = bifrostUsage + entry.PromptTokens = bifrostUsage.PromptTokens + entry.CompletionTokens = bifrostUsage.CompletionTokens + entry.TotalTokens = bifrostUsage.TotalTokens + } + + contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging + + if contentLoggingEnabled { + if outputMessage := extractRealtimeOutputMessage(result.ResponsesResponse.Output); outputMessage != nil { + entry.OutputMessageParsed = outputMessage + } + } + + extraFields := result.GetExtraFields() + applyRealtimeRawRequestBackfill(entry, extraFields.RawRequest, contentLoggingEnabled) + if contentLoggingEnabled && extraFields.RawResponse != nil { + switch raw := extraFields.RawResponse.(type) { + case string: + entry.RawResponse = strings.TrimSpace(raw) + default: + if rawResponseBytes, err := sonic.Marshal(extraFields.RawResponse); err == nil { + entry.RawResponse = string(rawResponseBytes) + } + } + } +} + +// applyRealtimeRawRequestBackfill writes RawRequest onto entry from an +// ExtraFields.RawRequest value (string or marshalable) and rebuilds +// InputHistoryParsed from any embedded realtime user/transcript events. +// Used by both success and error paths so realtime turns that fail mid-stream +// still surface their input transcript in logs. +func applyRealtimeRawRequestBackfill(entry *logstore.Log, rawRequest any, contentLoggingEnabled bool) { + if !contentLoggingEnabled || rawRequest == nil { + return + } + switch raw := rawRequest.(type) { + case string: + entry.RawRequest = strings.TrimSpace(raw) + default: + if rawRequestBytes, err := sonic.Marshal(rawRequest); err == nil { + entry.RawRequest = string(rawRequestBytes) + } + } + if strings.TrimSpace(entry.RawRequest) == "" { + return + } + if inputHistory := extractRealtimeInputHistoryFromRawRequest(entry.RawRequest); len(inputHistory) > 0 { + entry.InputHistoryParsed = mergeRealtimeInputHistory(entry.InputHistoryParsed, inputHistory) + } +} + +func extractRealtimeInputHistoryFromRawRequest(rawRequest string) []schemas.ChatMessage { + rawRequest = strings.TrimSpace(rawRequest) + if rawRequest == "" { + return nil + } + + parts := strings.Split(rawRequest, "\n\n") + messages := make([]schemas.ChatMessage, 0, len(parts)) + for _, part := range parts { + event, err := schemas.ParseRealtimeEvent([]byte(strings.TrimSpace(part))) + if err != nil || event == nil { + continue + } + + switch { + case schemas.IsRealtimeInputTranscriptEvent(event): + if transcript := extractRealtimeTranscript(event); transcript != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(transcript), + }, + }) + } + case schemas.IsRealtimeUserInputEvent(event): + if content := extractRealtimeRawItemContent(event.Item); content != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + }) + } + case schemas.IsRealtimeToolOutputEvent(event): + if content := extractRealtimeRawItemContent(event.Item); content != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr(event.Item.CallID), + }, + }) + } + } + } + + if len(messages) == 0 { + return nil + } + return messages +} + +func mergeRealtimeInputHistory(existing, backfill []schemas.ChatMessage) []schemas.ChatMessage { + if len(backfill) == 0 { + return existing + } + + // Run dedupe even when existing is empty so duplicate events inside the + // same raw-event blob (same turn captured twice) collapse instead of + // getting written out verbatim. + merged := append([]schemas.ChatMessage(nil), existing...) + for _, candidate := range backfill { + if realtimeInputHistoryContainsEquivalent(merged, candidate) { + continue + } + if candidate.Role == schemas.ChatMessageRoleUser { + inserted := false + for idx, msg := range merged { + if msg.Role == schemas.ChatMessageRoleTool { + merged = append(merged[:idx], append([]schemas.ChatMessage{candidate}, merged[idx:]...)...) + inserted = true + break + } + } + if inserted { + continue + } + } + merged = append(merged, candidate) + } + return merged +} + +func realtimeInputHistoryContainsEquivalent(history []schemas.ChatMessage, candidate schemas.ChatMessage) bool { + candidateContent := strings.TrimSpace(realtimeInputHistoryMessageContent(candidate)) + candidateToolCallID := strings.TrimSpace(realtimeInputHistoryToolCallID(candidate)) + + for _, existing := range history { + if existing.Role != candidate.Role { + continue + } + if strings.TrimSpace(realtimeInputHistoryMessageContent(existing)) != candidateContent { + continue + } + if strings.TrimSpace(realtimeInputHistoryToolCallID(existing)) != candidateToolCallID { + continue + } + return true + } + + return false +} + +func realtimeInputHistoryMessageContent(message schemas.ChatMessage) string { + if message.Content == nil || message.Content.ContentStr == nil { + return "" + } + return *message.Content.ContentStr +} + +func realtimeInputHistoryToolCallID(message schemas.ChatMessage) string { + if message.ChatToolMessage == nil || message.ChatToolMessage.ToolCallID == nil { + return "" + } + return *message.ChatToolMessage.ToolCallID +} + +func extractRealtimeTranscript(event *schemas.BifrostRealtimeEvent) string { + if event == nil || event.ExtraParams == nil { + return realtimeMissingTranscriptText + } + raw, ok := event.ExtraParams["transcript"] + if !ok || len(raw) == 0 { + return realtimeMissingTranscriptText + } + var transcript string + if err := schemas.Unmarshal(raw, &transcript); err != nil { + return realtimeMissingTranscriptText + } + transcript = strings.TrimSpace(transcript) + if transcript == "" { + return realtimeMissingTranscriptText + } + return transcript +} + +func extractRealtimeRawItemContent(item *schemas.RealtimeItem) string { + if item == nil { + return "" + } + if content := extractRealtimeRawContent(item.Content); content != "" { + return content + } + if item.Role == "user" && realtimeItemHasMissingAudioTranscript(item) { + return realtimeMissingTranscriptText + } + switch { + case strings.TrimSpace(item.Output) != "": + return strings.TrimSpace(item.Output) + case strings.TrimSpace(item.Arguments) != "": + return strings.TrimSpace(item.Arguments) + default: + return "" + } +} + +func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool { + if item == nil || len(item.Content) == 0 { + return false + } + + var decoded []map[string]any + if err := sonic.Unmarshal(item.Content, &decoded); err != nil { + return false + } + + for _, part := range decoded { + partType, _ := part["type"].(string) + if partType != "input_audio" { + continue + } + transcript, exists := part["transcript"] + if !exists || transcript == nil { + return true + } + if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" { + return true + } + } + + return false +} + +func extractRealtimeRawContent(raw []byte) string { + if len(raw) == 0 { + return "" + } + + var decoded any + if err := sonic.Unmarshal(raw, &decoded); err != nil { + return strings.TrimSpace(string(raw)) + } + + var parts []string + collectRealtimeRawTextFragments(decoded, &parts) + return strings.TrimSpace(strings.Join(parts, " ")) +} + +func collectRealtimeRawTextFragments(value any, parts *[]string) { + switch v := value.(type) { + case map[string]any: + for key, field := range v { + switch key { + case "text", "transcript", "input_text", "output_text", "output", "arguments": + if text, ok := field.(string); ok { + text = strings.TrimSpace(text) + if text != "" { + *parts = append(*parts, text) + } + continue + } + } + collectRealtimeRawTextFragments(field, parts) + } + case []any: + for _, item := range v { + collectRealtimeRawTextFragments(item, parts) + } + } +} + +func extractRealtimeOutputMessage(output []schemas.ResponsesMessage) *schemas.ChatMessage { + var contentParts []string + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, item := range output { + if item.Type == nil { + continue + } + switch *item.Type { + case schemas.ResponsesMessageTypeMessage: + if item.Role == nil || *item.Role != schemas.ResponsesInputMessageRoleAssistant { + continue + } + if text := extractRealtimeResponsesContent(item.Content); text != "" { + contentParts = append(contentParts, text) + } + case schemas.ResponsesMessageTypeFunctionCall: + if item.ResponsesToolMessage == nil || item.ResponsesToolMessage.Name == nil { + continue + } + toolType := "function" + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: item.ResponsesToolMessage.Name, + Arguments: derefString(item.ResponsesToolMessage.Arguments), + }, + } + if item.CallID != nil && strings.TrimSpace(*item.CallID) != "" { + toolCall.ID = schemas.Ptr(strings.TrimSpace(*item.CallID)) + } else if item.ID != nil && strings.TrimSpace(*item.ID) != "" { + toolCall.ID = schemas.Ptr(strings.TrimSpace(*item.ID)) + } + toolCalls = append(toolCalls, toolCall) + } + } + + if len(contentParts) == 0 && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if len(contentParts) > 0 { + content := strings.Join(contentParts, "\n") + message.Content = &schemas.ChatMessageContent{ContentStr: &content} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + return message +} + +func derefString(value *string) string { + if value == nil { + return "" + } + return *value +} + // SearchLogs searches logs with filters and pagination using GORM func (p *LoggerPlugin) SearchLogs(ctx context.Context, filters logstore.SearchFilters, pagination logstore.PaginationOptions) (*logstore.SearchResult, error) { // Set default pagination if not provided @@ -734,6 +1085,25 @@ func (p *LoggerPlugin) SearchLogs(ctx context.Context, filters logstore.SearchFi return p.store.SearchLogs(ctx, filters, pagination) } +// GetSessionLogs returns paginated logs for a single parent_request_id session. +func (p *LoggerPlugin) GetSessionLogs(ctx context.Context, sessionID string, pagination logstore.PaginationOptions) (*logstore.SessionDetailResult, error) { + if pagination.Limit == 0 { + pagination.Limit = 50 + } + if pagination.SortBy == "" { + pagination.SortBy = "timestamp" + } + if pagination.Order == "" { + pagination.Order = "asc" + } + return p.store.GetSessionLogs(ctx, sessionID, pagination) +} + +// GetSessionSummary returns aggregate totals for a single parent_request_id session. +func (p *LoggerPlugin) GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) { + return p.store.GetSessionSummary(ctx, sessionID) +} + // GetLog retrieves a single log entry by ID including all fields (raw_request, raw_response). func (p *LoggerPlugin) GetLog(ctx context.Context, id string) (*logstore.Log, error) { return p.store.FindByID(ctx, id) @@ -799,6 +1169,16 @@ func (p *LoggerPlugin) GetAvailableModels(ctx context.Context) []string { return models } +// GetAvailableAliases returns all unique alias values from logs. +func (p *LoggerPlugin) GetAvailableAliases(ctx context.Context) []string { + aliases, err := p.store.GetDistinctAliases(ctx) + if err != nil { + p.logger.Error("failed to get available aliases: %v", err) + return []string{} + } + return aliases +} + func (p *LoggerPlugin) GetAvailableSelectedKeys(ctx context.Context) []KeyPair { results, err := p.store.GetDistinctKeyPairs(ctx, "selected_key_id", "selected_key_name") if err != nil { @@ -826,6 +1206,68 @@ func (p *LoggerPlugin) GetAvailableRoutingRules(ctx context.Context) []KeyPair { return keyPairResultsToKeyPairs(results) } +// GetAvailableTeams returns all unique team ID-Name pairs from logs. +// Uses DISTINCT to avoid loading all rows when only unique values are needed. +func (p *LoggerPlugin) GetAvailableTeams(ctx context.Context) []KeyPair { + results, err := p.store.GetDistinctKeyPairs(ctx, "team_id", "team_name") + if err != nil { + p.logger.Error("failed to get available teams: %v", err) + return []KeyPair{} + } + return keyPairResultsToKeyPairs(results) +} + +// GetAvailableCustomers returns all unique customer ID-Name pairs from logs. +// Uses DISTINCT to avoid loading all rows when only unique values are needed. +func (p *LoggerPlugin) GetAvailableCustomers(ctx context.Context) []KeyPair { + results, err := p.store.GetDistinctKeyPairs(ctx, "customer_id", "customer_name") + if err != nil { + p.logger.Error("failed to get available customers: %v", err) + return []KeyPair{} + } + return keyPairResultsToKeyPairs(results) +} + +// GetAvailableUsers returns all unique user IDs from logs. +// Both ID and Name are set to user_id since users don't have a separate name column. +func (p *LoggerPlugin) GetAvailableUsers(ctx context.Context) []KeyPair { + results, err := p.store.GetDistinctKeyPairs(ctx, "user_id", "user_id") + if err != nil { + p.logger.Error("failed to get available users: %v", err) + return []KeyPair{} + } + return keyPairResultsToKeyPairs(results) +} + +// GetAvailableBusinessUnits returns all unique business unit ID-Name pairs from logs. +// Uses DISTINCT to avoid loading all rows when only unique values are needed. +func (p *LoggerPlugin) GetAvailableBusinessUnits(ctx context.Context) []KeyPair { + results, err := p.store.GetDistinctKeyPairs(ctx, "business_unit_id", "business_unit_name") + if err != nil { + p.logger.Error("failed to get available business units: %v", err) + return []KeyPair{} + } + return keyPairResultsToKeyPairs(results) +} + +// GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension. +// Delegates to the underlying log store which uses materialized views on PostgreSQL for performance. +func (p *LoggerPlugin) GetDimensionCostHistogram(ctx context.Context, filters logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionCostHistogramResult, error) { + return p.store.GetDimensionCostHistogram(ctx, filters, bucketSizeSeconds, dimension) +} + +// GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension. +// Delegates to the underlying log store which uses materialized views on PostgreSQL for performance. +func (p *LoggerPlugin) GetDimensionTokenHistogram(ctx context.Context, filters logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionTokenHistogramResult, error) { + return p.store.GetDimensionTokenHistogram(ctx, filters, bucketSizeSeconds, dimension) +} + +// GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension. +// Delegates to the underlying log store which uses materialized views on PostgreSQL for performance. +func (p *LoggerPlugin) GetDimensionLatencyHistogram(ctx context.Context, filters logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionLatencyHistogramResult, error) { + return p.store.GetDimensionLatencyHistogram(ctx, filters, bucketSizeSeconds, dimension) +} + // GetAvailableRoutingEngines returns all unique routing engine types used in logs. // Uses DISTINCT to avoid loading all rows when only unique values are needed. func (p *LoggerPlugin) GetAvailableRoutingEngines(ctx context.Context) []string { @@ -975,11 +1417,17 @@ func (p *LoggerPlugin) calculateCostForLog(logEntry *logstore.Log) (float64, err // Build a minimal BifrostResponse matching the request type so that // extractCostInput routes usage into the correct field for each compute function. + originalModelRequested := logEntry.Model + if logEntry.Alias != nil && *logEntry.Alias != "" { + originalModelRequested = *logEntry.Alias + } + extraFields := schemas.BifrostResponseExtraFields{ - RequestType: requestType, - Provider: schemas.ModelProvider(logEntry.Provider), - ModelRequested: logEntry.Model, - CacheDebug: cacheDebug, + RequestType: requestType, + Provider: schemas.ModelProvider(logEntry.Provider), + OriginalModelRequested: originalModelRequested, + ResolvedModelUsed: logEntry.Model, + CacheDebug: cacheDebug, } resp := buildResponseForRequestType(requestType, usage, extraFields) @@ -1025,7 +1473,8 @@ func (p *LoggerPlugin) calculateCostForLog(logEntry *logstore.Log) (float64, err resp.SpeechResponse.Usage = logEntry.SpeechOutputParsed.Usage } - return p.pricingManager.CalculateCost(resp), nil + scopes := pricingScopesForLog(logEntry) + return p.pricingManager.CalculateCost(resp, &scopes), nil } // buildResponseForRequestType wraps BifrostLLMUsage into the correct response @@ -1073,19 +1522,19 @@ func buildResponseForRequestType(requestType schemas.RequestType, usage *schemas CachedWriteTokens: usage.PromptTokensDetails.CachedWriteTokens, } } - if usage.CompletionTokensDetails != nil { - respUsage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ - TextTokens: usage.CompletionTokensDetails.TextTokens, - AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens, - AudioTokens: usage.CompletionTokensDetails.AudioTokens, - ImageTokens: usage.CompletionTokensDetails.ImageTokens, - ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, - RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens, - CitationTokens: usage.CompletionTokensDetails.CitationTokens, - NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries, + if usage.CompletionTokensDetails != nil { + respUsage.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ + TextTokens: usage.CompletionTokensDetails.TextTokens, + AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens, + AudioTokens: usage.CompletionTokensDetails.AudioTokens, + ImageTokens: usage.CompletionTokensDetails.ImageTokens, + ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, + RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens, + CitationTokens: usage.CompletionTokensDetails.CitationTokens, + NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries, + } } } - } return &schemas.BifrostResponse{ ResponsesResponse: &schemas.BifrostResponsesResponse{ Usage: respUsage, @@ -1157,3 +1606,20 @@ func buildResponseForRequestType(requestType schemas.RequestType, usage *schemas } } } + +func pricingScopesForLog(logEntry *logstore.Log) modelcatalog.PricingLookupScopes { + if logEntry == nil { + return modelcatalog.PricingLookupScopes{} + } + + virtualKeyID := "" + if logEntry.VirtualKeyID != nil { + virtualKeyID = *logEntry.VirtualKeyID + } + + return modelcatalog.PricingLookupScopes{ + Provider: logEntry.Provider, + SelectedKeyID: logEntry.SelectedKeyID, + VirtualKeyID: virtualKeyID, + } +} diff --git a/plugins/logging/operations_test.go b/plugins/logging/operations_test.go index daa2ec91ff..7b12b33eed 100644 --- a/plugins/logging/operations_test.go +++ b/plugins/logging/operations_test.go @@ -243,3 +243,407 @@ func TestUpdateStreamingLogEntryPreservesResponsesInputContentSummary(t *testing t.Fatalf("expected content summary to avoid overwriting with streamed responses output-only data, got %q", logEntry.ContentSummary) } } + +func TestStoreOrEnqueueRetryPreservesAllEntries(t *testing.T) { + // Simulate fallback/retry scenario where multiple PostLLMHook calls + // store entries under the same traceID. All entries must be preserved. + plugin := &LoggerPlugin{ + logger: testLogger{}, + writeQueue: make(chan *writeQueueEntry, 10), + } + + traceID := "trace-retry-test" + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + + // Simulate 3 retry attempts storing entries under the same traceID + entry1 := &logstore.Log{ID: "req-attempt-1", Model: "gpt-4o"} + entry2 := &logstore.Log{ID: "req-attempt-2", Model: "gpt-4o"} + entry3 := &logstore.Log{ID: "req-attempt-3", Model: "claude-3-5-sonnet"} + + plugin.storeOrEnqueueEntry(ctx, entry1, nil) + plugin.storeOrEnqueueEntry(ctx, entry2, nil) + plugin.storeOrEnqueueEntry(ctx, entry3, nil) + + // Verify all 3 entries are stored + val, ok := plugin.pendingLogsToInject.Load(traceID) + if !ok { + t.Fatal("expected pending entries for traceID, got none") + } + pending, ok := val.(*pendingInjectEntries) + if !ok { + t.Fatal("expected *pendingInjectEntries type") + } + if len(pending.entries) != 3 { + t.Fatalf("expected 3 entries, got %d", len(pending.entries)) + } + if pending.entries[0].ID != "req-attempt-1" || pending.entries[1].ID != "req-attempt-2" || pending.entries[2].ID != "req-attempt-3" { + t.Fatalf("entries not in expected order: %v, %v, %v", pending.entries[0].ID, pending.entries[1].ID, pending.entries[2].ID) + } + + // Now test Inject flushes all entries with plugin logs attached + trace := &schemas.Trace{ + TraceID: traceID, + PluginLogs: []schemas.PluginLogEntry{ + {PluginName: "hello-world", Level: schemas.LogLevelInfo, Message: "test log"}, + }, + } + + if err := plugin.Inject(context.Background(), trace); err != nil { + t.Fatalf("Inject() error = %v", err) + } + + // Verify all 3 entries were enqueued to writeQueue + if len(plugin.writeQueue) != 3 { + t.Fatalf("expected 3 entries in writeQueue, got %d", len(plugin.writeQueue)) + } + + // Verify plugin logs were attached to each entry + for i := 0; i < 3; i++ { + qe := <-plugin.writeQueue + if qe.log.PluginLogs == "" { + t.Fatalf("entry %d: expected PluginLogs to be set", i) + } + } + + // Verify pendingLogsToInject was cleaned up + if _, ok := plugin.pendingLogsToInject.Load(traceID); ok { + t.Fatal("expected pendingLogsToInject to be cleaned up after Inject") + } +} + +func TestApplyRealtimeOutputToEntryBackfillsUserTranscriptFromRawRequest(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hello!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.input_audio_transcription.completed","transcript":"Hello."}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Hello." { + t.Fatalf("InputHistoryParsed[0] = %+v, want transcript", entry.InputHistoryParsed[0]) + } + if entry.OutputMessageParsed == nil || entry.OutputMessageParsed.Content == nil || entry.OutputMessageParsed.Content.ContentStr == nil || *entry.OutputMessageParsed.Content.ContentStr != assistantText { + t.Fatalf("OutputMessageParsed = %+v, want assistant text", entry.OutputMessageParsed) + } + if !strings.Contains(entry.ContentSummary, "Hello.") { + t.Fatalf("ContentSummary = %q, want user transcript", entry.ContentSummary) + } + if !strings.Contains(entry.ContentSummary, "Hello!") { + t.Fatalf("ContentSummary = %q, want assistant text", entry.ContentSummary) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsMissingTranscriptPlaceholder(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hi there!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != realtimeMissingTranscriptText { + t.Fatalf("InputHistoryParsed[0] = %+v, want missing transcript placeholder", entry.InputHistoryParsed[0]) + } + if !strings.Contains(entry.ContentSummary, realtimeMissingTranscriptText) { + t.Fatalf("ContentSummary = %q, want missing transcript placeholder", entry.ContentSummary) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsDoneMissingTranscriptPlaceholder(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hi there!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.done","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_audio","transcript":null}]}}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != realtimeMissingTranscriptText { + t.Fatalf("InputHistoryParsed[0] = %+v, want missing transcript placeholder", entry.InputHistoryParsed[0]) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsRetrievedUserAndToolHistory(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "I checked that for you." + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.retrieved","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"Where is my order?"}]}}`, + `{"type":"conversation.item.retrieved","item":{"id":"item_tool","type":"function_call_output","call_id":"call_123","status":"completed","output":"{\"status\":\"delivered\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Where is my order?" { + t.Fatalf("InputHistoryParsed[0] = %+v, want user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[1].Role = %q, want tool", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != `{"status":"delivered"}` { + t.Fatalf("InputHistoryParsed[1] = %+v, want tool content", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_123" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsCreatedUserAndToolHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + ExtraFields: schemas.BifrostResponseExtraFields{ + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.created","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"I need help"}]}}`, + `{"type":"conversation.item.created","item":{"id":"item_tool","type":"function_call_output","call_id":"call_456","status":"completed","output":"{\"status\":\"ok\"}"}}`, + }, "\n\n"), + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "I need help" { + t.Fatalf("InputHistoryParsed[0] = %+v, want user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[1].Role = %q, want tool", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != `{"status":"ok"}` { + t.Fatalf("InputHistoryParsed[1] = %+v, want tool content", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_456" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsAddedUserAndToolHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Done." + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.added","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"hello from added item"}]}}`, + `{"type":"conversation.item.added","item":{"id":"item_tool","type":"function_call_output","call_id":"call_added","status":"completed","output":"{\"status\":\"ok\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "hello from added item" { + t.Fatalf("InputHistoryParsed[0] = %+v, want added user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_added" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want added tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryMergesRawTranscriptIntoStructuredRealtimeHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := &logstore.Log{ + InputHistoryParsed: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Can you help with my ticket?"), + }, + }, + { + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"status":"open"}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_789"), + }, + }, + }, + } + + assistantText := "Let me check." + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.input_audio_transcription.completed","transcript":"Hello."}`, + `{"type":"conversation.item.retrieved","item":{"id":"item_tool","type":"function_call_output","call_id":"call_789","status":"completed","output":"{\"status\":\"open\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 3 { + t.Fatalf("len(InputHistoryParsed) = %d, want 3", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Can you help with my ticket?" { + t.Fatalf("InputHistoryParsed[0] = %+v, want structured user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[1].Role = %q, want user", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != "Hello." { + t.Fatalf("InputHistoryParsed[1] = %+v, want raw transcript merge", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[2].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[2].Role = %q, want tool", entry.InputHistoryParsed[2].Role) + } + if entry.InputHistoryParsed[2].ChatToolMessage == nil || entry.InputHistoryParsed[2].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[2].ChatToolMessage.ToolCallID != "call_789" { + t.Fatalf("InputHistoryParsed[2].ChatToolMessage = %+v, want original tool call id", entry.InputHistoryParsed[2].ChatToolMessage) + } + if strings.Count(entry.ContentSummary, "Hello.") != 1 { + t.Fatalf("ContentSummary = %q, want one merged transcript", entry.ContentSummary) + } +} diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 80bc953e99..4d1abbbde5 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -8,6 +8,7 @@ import ( "strings" "time" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/streaming" @@ -27,6 +28,12 @@ type LogManager interface { // Search searches for log entries based on filters and pagination Search(ctx context.Context, filters *logstore.SearchFilters, pagination *logstore.PaginationOptions) (*logstore.SearchResult, error) + // GetSessionLogs returns paginated logs for a single parent_request_id session. + GetSessionLogs(ctx context.Context, sessionID string, pagination *logstore.PaginationOptions) (*logstore.SessionDetailResult, error) + + // GetSessionSummary returns aggregate totals for a single parent_request_id session. + GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) + // GetStats calculates statistics for logs matching the given filters GetStats(ctx context.Context, filters *logstore.SearchFilters) (*logstore.SearchStats, error) @@ -63,6 +70,9 @@ type LogManager interface { // GetAvailableModels returns all unique models from logs GetAvailableModels(ctx context.Context) []string + // GetAvailableAliases returns all unique alias values from logs + GetAvailableAliases(ctx context.Context) []string + // GetAvailableSelectedKeys returns all unique selected key ID-Name pairs from logs GetAvailableSelectedKeys(ctx context.Context) []KeyPair @@ -75,9 +85,30 @@ type LogManager interface { // GetAvailableRoutingEngines returns all unique routing engine types from logs GetAvailableRoutingEngines(ctx context.Context) []string + // GetAvailableTeams returns all unique team ID-Name pairs from logs + GetAvailableTeams(ctx context.Context) []KeyPair + + // GetAvailableCustomers returns all unique customer ID-Name pairs from logs + GetAvailableCustomers(ctx context.Context) []KeyPair + + // GetAvailableUsers returns all unique user IDs from logs + GetAvailableUsers(ctx context.Context) []KeyPair + + // GetAvailableBusinessUnits returns all unique business unit ID-Name pairs from logs + GetAvailableBusinessUnits(ctx context.Context) []KeyPair + // GetAvailableMetadataKeys returns distinct metadata keys and their values from recent logs GetAvailableMetadataKeys(ctx context.Context) (map[string][]string, error) + // GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension + GetDimensionCostHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionCostHistogramResult, error) + + // GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension + GetDimensionTokenHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionTokenHistogramResult, error) + + // GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension + GetDimensionLatencyHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionLatencyHistogramResult, error) + // DeleteLog deletes a log entry by its ID DeleteLog(ctx context.Context, id string) error @@ -132,6 +163,23 @@ func (p *PluginLogManager) Search(ctx context.Context, filters *logstore.SearchF return p.plugin.SearchLogs(ctx, *filters, *pagination) } +func (p *PluginLogManager) GetSessionLogs(ctx context.Context, sessionID string, pagination *logstore.PaginationOptions) (*logstore.SessionDetailResult, error) { + if pagination == nil { + return nil, fmt.Errorf("pagination cannot be nil") + } + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + return p.plugin.GetSessionLogs(ctx, sessionID, *pagination) +} + +func (p *PluginLogManager) GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + return p.plugin.GetSessionSummary(ctx, sessionID) +} + func (p *PluginLogManager) GetStats(ctx context.Context, filters *logstore.SearchFilters) (*logstore.SearchStats, error) { if filters == nil { return nil, fmt.Errorf("filters cannot be nil") @@ -211,6 +259,11 @@ func (p *PluginLogManager) GetAvailableModels(ctx context.Context) []string { return p.plugin.GetAvailableModels(ctx) } +// GetAvailableAliases returns all unique alias values from logs +func (p *PluginLogManager) GetAvailableAliases(ctx context.Context) []string { + return p.plugin.GetAvailableAliases(ctx) +} + // GetAvailableSelectedKeys returns all unique selected key ID-Name pairs from logs func (p *PluginLogManager) GetAvailableSelectedKeys(ctx context.Context) []KeyPair { return p.plugin.GetAvailableSelectedKeys(ctx) @@ -231,6 +284,50 @@ func (p *PluginLogManager) GetAvailableRoutingEngines(ctx context.Context) []str return p.plugin.GetAvailableRoutingEngines(ctx) } +// GetAvailableTeams returns all unique team ID-Name pairs from logs. +func (p *PluginLogManager) GetAvailableTeams(ctx context.Context) []KeyPair { + return p.plugin.GetAvailableTeams(ctx) +} + +// GetAvailableCustomers returns all unique customer ID-Name pairs from logs. +func (p *PluginLogManager) GetAvailableCustomers(ctx context.Context) []KeyPair { + return p.plugin.GetAvailableCustomers(ctx) +} + +// GetAvailableUsers returns all unique user IDs from logs. +func (p *PluginLogManager) GetAvailableUsers(ctx context.Context) []KeyPair { + return p.plugin.GetAvailableUsers(ctx) +} + +// GetAvailableBusinessUnits returns all unique business unit ID-Name pairs from logs. +func (p *PluginLogManager) GetAvailableBusinessUnits(ctx context.Context) []KeyPair { + return p.plugin.GetAvailableBusinessUnits(ctx) +} + +// GetDimensionCostHistogram returns time-bucketed cost data grouped by the specified dimension. +func (p *PluginLogManager) GetDimensionCostHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionCostHistogramResult, error) { + if filters == nil { + return nil, fmt.Errorf("filters cannot be nil") + } + return p.plugin.GetDimensionCostHistogram(ctx, *filters, bucketSizeSeconds, dimension) +} + +// GetDimensionTokenHistogram returns time-bucketed token usage grouped by the specified dimension. +func (p *PluginLogManager) GetDimensionTokenHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionTokenHistogramResult, error) { + if filters == nil { + return nil, fmt.Errorf("filters cannot be nil") + } + return p.plugin.GetDimensionTokenHistogram(ctx, *filters, bucketSizeSeconds, dimension) +} + +// GetDimensionLatencyHistogram returns time-bucketed latency percentiles grouped by the specified dimension. +func (p *PluginLogManager) GetDimensionLatencyHistogram(ctx context.Context, filters *logstore.SearchFilters, bucketSizeSeconds int64, dimension logstore.HistogramDimension) (*logstore.DimensionLatencyHistogramResult, error) { + if filters == nil { + return nil, fmt.Errorf("filters cannot be nil") + } + return p.plugin.GetDimensionLatencyHistogram(ctx, *filters, bucketSizeSeconds, dimension) +} + func (p *PluginLogManager) GetAvailableMetadataKeys(ctx context.Context) (map[string][]string, error) { if p.plugin == nil || p.plugin.store == nil { return map[string][]string{}, nil @@ -378,6 +475,9 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s if request.ChatRequest != nil { return request.ChatRequest.Input, []schemas.ResponsesMessage{} } + if request.RequestType == schemas.RealtimeRequest && request.ResponsesRequest != nil { + return extractRealtimeInputHistory(request.ResponsesRequest.Input), []schemas.ResponsesMessage{} + } if request.ResponsesRequest != nil && len(request.ResponsesRequest.Input) > 0 { return []schemas.ChatMessage{}, request.ResponsesRequest.Input } @@ -451,6 +551,96 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s return []schemas.ChatMessage{}, []schemas.ResponsesMessage{} } +func extractRealtimeInputHistory(input []schemas.ResponsesMessage) []schemas.ChatMessage { + messages := make([]schemas.ChatMessage, 0, len(input)) + for _, item := range input { + if item.Type == nil { + continue + } + switch *item.Type { + case schemas.ResponsesMessageTypeMessage: + if item.Role == nil || item.Content == nil { + continue + } + content := extractRealtimeResponsesContent(item.Content) + if content == "" { + continue + } + messages = append(messages, schemas.ChatMessage{ + Role: mapRealtimeResponsesRole(*item.Role), + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + }) + case schemas.ResponsesMessageTypeFunctionCallOutput, + schemas.ResponsesMessageTypeCustomToolCallOutput, + schemas.ResponsesMessageTypeLocalShellCallOutput, + schemas.ResponsesMessageTypeComputerCallOutput: + content := extractRealtimeToolOutputContent(item.ResponsesToolMessage) + if content == "" { + continue + } + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: item.ResponsesToolMessage.CallID, + }, + }) + } + } + return messages +} + +func mapRealtimeResponsesRole(role schemas.ResponsesMessageRoleType) schemas.ChatMessageRole { + switch role { + case schemas.ResponsesInputMessageRoleAssistant: + return schemas.ChatMessageRoleAssistant + case schemas.ResponsesInputMessageRoleSystem: + return schemas.ChatMessageRoleSystem + case schemas.ResponsesInputMessageRoleDeveloper: + return schemas.ChatMessageRoleDeveloper + default: + return schemas.ChatMessageRoleUser + } +} + +func extractRealtimeResponsesContent(content *schemas.ResponsesMessageContent) string { + if content == nil { + return "" + } + if content.ContentStr != nil { + return strings.TrimSpace(*content.ContentStr) + } + parts := make([]string, 0, len(content.ContentBlocks)) + for _, block := range content.ContentBlocks { + switch { + case block.Text != nil && strings.TrimSpace(*block.Text) != "": + parts = append(parts, strings.TrimSpace(*block.Text)) + case block.ResponsesOutputMessageContentRefusal != nil && strings.TrimSpace(block.Refusal) != "": + parts = append(parts, strings.TrimSpace(block.Refusal)) + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) +} + +func extractRealtimeToolOutputContent(toolMessage *schemas.ResponsesToolMessage) string { + if toolMessage == nil || toolMessage.Output == nil { + return "" + } + switch { + case toolMessage.Output.ResponsesToolCallOutputStr != nil: + return strings.TrimSpace(*toolMessage.Output.ResponsesToolCallOutputStr) + case len(toolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0: + content := &schemas.ResponsesMessageContent{ContentBlocks: toolMessage.Output.ResponsesFunctionToolCallOutputBlocks} + return extractRealtimeResponsesContent(content) + default: + return "" + } +} + // convertToProcessedStreamResponse converts a StreamAccumulatorResult to ProcessedStreamResponse // for use with the logging plugin's streaming log update functionality. func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, requestType schemas.RequestType) *streaming.ProcessedStreamResponse { @@ -480,7 +670,7 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r // Build accumulated data data := &streaming.AccumulatedData{ RequestID: result.RequestID, - Model: result.Model, + Model: result.RequestedModel, Status: result.Status, Stream: true, Latency: result.Latency, @@ -503,11 +693,12 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r } resp := &streaming.ProcessedStreamResponse{ - RequestID: result.RequestID, - StreamType: streamType, - Provider: result.Provider, - Model: result.Model, - Data: data, + RequestID: result.RequestID, + StreamType: streamType, + Provider: result.Provider, + RequestedModel: result.RequestedModel, + ResolvedModel: result.ResolvedModel, + Data: data, } if result.RawRequest != nil { @@ -518,6 +709,32 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r return resp } +func mergeRealtimeMetadata(metadata map[string]interface{}, ctx *schemas.BifrostContext) map[string]interface{} { + if ctx == nil { + return metadata + } + set := func(key string, ctxKey schemas.BifrostContextKey) { + if value := bifrost.GetStringFromContext(ctx, ctxKey); value != "" { + if metadata == nil { + metadata = make(map[string]interface{}) + } + metadata[key] = value + } + } + + set("realtime_session_id", schemas.BifrostContextKeyRealtimeSessionID) + set("provider_session_id", schemas.BifrostContextKeyRealtimeProviderSessionID) + set("realtime_source", schemas.BifrostContextKeyRealtimeSource) + set("realtime_event_type", schemas.BifrostContextKeyRealtimeEventType) + if bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRealtimeSessionID) != "" { + if metadata == nil { + metadata = make(map[string]interface{}) + } + metadata["realtime"] = true + } + return metadata +} + // formatRoutingEngineLogs formats routing engine logs into a human-readable string. // Format: [timestamp] [engine] - message // Parameters: diff --git a/plugins/logging/version b/plugins/logging/version index 62f0c2cadb..8e03717dca 100644 --- a/plugins/logging/version +++ b/plugins/logging/version @@ -1 +1 @@ -1.4.36 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/plugins/logging/writer.go b/plugins/logging/writer.go index e2a5861655..61afcdb792 100644 --- a/plugins/logging/writer.go +++ b/plugins/logging/writer.go @@ -1,6 +1,7 @@ package logging import ( + "sync" "time" "github.com/maximhq/bifrost/framework/logstore" @@ -36,6 +37,14 @@ type PendingLogData struct { CreatedAt time.Time // For cleanup of stale entries } +// pendingInjectEntries wraps a slice of log entries so it can be used with sync.Map. +// The mutex protects concurrent appends to the entries slice within the same traceID. +type pendingInjectEntries struct { + mu sync.Mutex + entries []*logstore.Log + createdAt time.Time +} + // writeQueueEntry is an entry pushed to the batch write queue. type writeQueueEntry struct { log *logstore.Log // Complete log entry ready for INSERT @@ -167,10 +176,18 @@ func (p *LoggerPlugin) processBatch(batch []*writeQueueEntry) { // never fires for a request (e.g., request was cancelled before reaching the provider). func (p *LoggerPlugin) cleanupStalePendingLogs() { cutoff := time.Now().Add(-pendingLogTTL) - p.pendingLogs.Range(func(key, value any) bool { + p.pendingLogsEntries.Range(func(key, value any) bool { if pending, ok := value.(*PendingLogData); ok { if pending.CreatedAt.Before(cutoff) { - p.pendingLogs.Delete(key) + p.pendingLogsEntries.Delete(key) + } + } + return true + }) + p.pendingLogsToInject.Range(func(key, value any) bool { + if pending, ok := value.(*pendingInjectEntries); ok { + if pending.createdAt.Before(cutoff) { + p.pendingLogsToInject.Delete(key) } } return true @@ -244,7 +261,7 @@ func estimateLogEntrySize(log *logstore.Log) int { len(log.PassthroughRequestBody) + len(log.PassthroughResponseBody) + len(log.ContentSummary) + - len(log.CacheDebug) + + len(log.CacheDebug) + len(log.RoutingEngineLogs) // Baseline for fixed-width columns and struct overhead return n + 512 @@ -298,6 +315,8 @@ func buildCompleteLogEntryFromPending(pending *PendingLogData) *logstore.Log { SpeechInputParsed: pending.InitialData.SpeechInput, TranscriptionInputParsed: pending.InitialData.TranscriptionInput, ImageGenerationInputParsed: pending.InitialData.ImageGenerationInput, + ImageEditInputParsed: pending.InitialData.ImageEditInput, + ImageVariationInputParsed: pending.InitialData.ImageVariationInput, PassthroughRequestBody: pending.InitialData.PassthroughRequestBody, } if pending.ParentRequestID != "" { @@ -309,12 +328,33 @@ func buildCompleteLogEntryFromPending(pending *PendingLogData) *logstore.Log { return entry } +// applyModelAlias sets entry.Model to resolvedModel (falling back to requestedModel if empty) +// and entry.Alias to requestedModel when the two differ (i.e. an alias mapping was applied). +func applyModelAlias(entry *logstore.Log, requestedModel, resolvedModel string) { + if resolvedModel != "" && resolvedModel != requestedModel { + entry.Model = resolvedModel + entry.Alias = &requestedModel + } else { + // No alias mapping; keep whichever value is non-empty as the model. + if resolvedModel != "" { + entry.Model = resolvedModel + } else if requestedModel != "" { + entry.Model = requestedModel + } + entry.Alias = nil + } +} + // applyOutputFieldsToEntry sets common output fields on a log entry. func applyOutputFieldsToEntry( entry *logstore.Log, selectedKeyID, selectedKeyName string, virtualKeyID, virtualKeyName string, routingRuleID, routingRuleName string, + teamID, teamName string, + customerID, customerName string, + userID string, + businessUnitID, businessUnitName string, numberOfRetries int, latency int64, ) { @@ -332,6 +372,27 @@ func applyOutputFieldsToEntry( if routingRuleName != "" { entry.RoutingRuleName = &routingRuleName } + if teamID != "" { + entry.TeamID = &teamID + } + if teamName != "" { + entry.TeamName = &teamName + } + if customerID != "" { + entry.CustomerID = &customerID + } + if customerName != "" { + entry.CustomerName = &customerName + } + if userID != "" { + entry.UserID = &userID + } + if businessUnitID != "" { + entry.BusinessUnitID = &businessUnitID + } + if businessUnitName != "" { + entry.BusinessUnitName = &businessUnitName + } if numberOfRetries != 0 { entry.NumberOfRetries = numberOfRetries } diff --git a/plugins/maxim/changelog.md b/plugins/maxim/changelog.md index e69de29bb2..d22c95cca2 100644 --- a/plugins/maxim/changelog.md +++ b/plugins/maxim/changelog.md @@ -0,0 +1,4 @@ +- feat: add per-user OAuth consent flow with identity selection and MCP authentication +- feat: add support for image generation requests +- feat: add realtime turn logging +- feat: add support for tracking userId, teamId, customerId, and businessUnitId diff --git a/plugins/maxim/go.mod b/plugins/maxim/go.mod index 5aa8ef0ce8..1df6829252 100644 --- a/plugins/maxim/go.mod +++ b/plugins/maxim/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/maxim go 1.26.1 require ( - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 github.com/maximhq/maxim-go v0.2.1 ) diff --git a/plugins/maxim/go.sum b/plugins/maxim/go.sum index 3f884c361b..62dc9883b5 100644 --- a/plugins/maxim/go.sum +++ b/plugins/maxim/go.sum @@ -193,10 +193,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/maximhq/maxim-go v0.2.1 h1:hCp8dQ4HsyyNC+y5HCUuY/HFD0sOnGkjL5MdYCHkgEQ= github.com/maximhq/maxim-go v0.2.1/go.mod h1:nwFznXy0Dn4mxXGU4X+BCnE3VP68L+FPEaW0yUgk96o= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 2d0044bc19..ed8f205dd0 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -119,10 +119,11 @@ func convertAccResultToProcessedStreamResponse(accResult *schemas.StreamAccumula streamType = streaming.StreamTypeImage } return &streaming.ProcessedStreamResponse{ - RequestID: accResult.RequestID, - StreamType: streamType, - Model: accResult.Model, - Provider: accResult.Provider, + RequestID: accResult.RequestID, + StreamType: streamType, + RequestedModel: accResult.RequestedModel, + ResolvedModel: accResult.ResolvedModel, + Provider: accResult.Provider, Data: &streaming.AccumulatedData{ Status: accResult.Status, Latency: accResult.Latency, @@ -246,6 +247,10 @@ func (plugin *Plugin) getOrCreateLogger(logRepoID string) (*logging.Logger, erro // - *schemas.BifrostRequest: The original request, unmodified // - error: Any error that occurred during trace/generation creation func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + if req != nil && req.RequestType == schemas.RealtimeRequest { + return req, nil, nil + } + var traceID string var traceName string var sessionID string @@ -522,6 +527,11 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro // - *schemas.BifrostError: The original error, unmodified // - error: Never returns an error as it handles missing IDs gracefully func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + requestType, _, _, _ := bifrost.GetResponseFields(result, bifrostErr) + if requestType == schemas.RealtimeRequest { + return result, bifrostErr, nil + } + // Get effective log repo ID for this request effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctx) if effectiveLogRepoID == "" { @@ -545,7 +555,11 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B isFinalChunk := bifrost.IsFinalChunk(ctx) go func() { - requestType, _, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, _, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) + modelTag := resolvedModel + if modelTag == "" { + modelTag = originalModel + } var streamResponse *streaming.ProcessedStreamResponse if bifrost.IsStreamRequestType(requestType) { @@ -650,11 +664,11 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B } } } - if hasGenerationID && generationID != "" { - logger.AddTagToGeneration(generationID, "model", string(model)) + if hasGenerationID && generationID != "" && modelTag != "" { + logger.AddTagToGeneration(generationID, "model", string(modelTag)) } - if hasTraceID && traceID != "" { - logger.AddTagToTrace(traceID, "model", string(model)) + if hasTraceID && traceID != "" && modelTag != "" { + logger.AddTagToTrace(traceID, "model", string(modelTag)) } // Flush only the effective logger that was used for this request logger.Flush() diff --git a/plugins/maxim/version b/plugins/maxim/version index 3367256a75..2eda823ff5 100644 --- a/plugins/maxim/version +++ b/plugins/maxim/version @@ -1 +1 @@ -1.5.36 \ No newline at end of file +1.6.1 \ No newline at end of file diff --git a/plugins/mocker/changelog.md b/plugins/mocker/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/mocker/changelog.md +++ b/plugins/mocker/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/mocker/go.mod b/plugins/mocker/go.mod index 337eb42d44..7f13a83bf1 100644 --- a/plugins/mocker/go.mod +++ b/plugins/mocker/go.mod @@ -4,7 +4,7 @@ go 1.26.1 require ( github.com/jaswdr/faker/v2 v2.8.0 - github.com/maximhq/bifrost/core v1.4.17 + github.com/maximhq/bifrost/core v1.5.1 ) require ( diff --git a/plugins/mocker/go.sum b/plugins/mocker/go.sum index e0c4eedb4d..f6778ae42a 100644 --- a/plugins/mocker/go.sum +++ b/plugins/mocker/go.sum @@ -111,8 +111,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/plugins/mocker/main.go b/plugins/mocker/main.go index 29189c35cb..d9ccb765f4 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -853,7 +853,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: req.RequestType, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -877,7 +877,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -905,7 +905,7 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, Latency: int64(time.Since(startTime).Milliseconds()), }, } @@ -959,7 +959,7 @@ func (p *MockerPlugin) generateErrorShortCircuit(req *schemas.BifrostRequest, re ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: req.RequestType, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, }, } @@ -1083,7 +1083,7 @@ func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*sche ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ChatCompletionRequest, Provider: provider, - ModelRequested: model, + OriginalModelRequested: model, }, }, }, diff --git a/plugins/mocker/version b/plugins/mocker/version index de17646bc0..8e03717dca 100644 --- a/plugins/mocker/version +++ b/plugins/mocker/version @@ -1 +1 @@ -1.4.35 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/plugins/otel/changelog.md b/plugins/otel/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/otel/changelog.md +++ b/plugins/otel/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/otel/go.mod b/plugins/otel/go.mod index 2dce8d732e..675b93da45 100644 --- a/plugins/otel/go.mod +++ b/plugins/otel/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/otel go 1.26.1 require ( - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 go.opentelemetry.io/otel v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.40.0 diff --git a/plugins/otel/go.sum b/plugins/otel/go.sum index 9e7565d3e9..0694a56918 100644 --- a/plugins/otel/go.sum +++ b/plugins/otel/go.sum @@ -197,10 +197,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/plugins/otel/version b/plugins/otel/version index 98924abee9..cb174d58a5 100644 --- a/plugins/otel/version +++ b/plugins/otel/version @@ -1 +1 @@ -1.1.35 \ No newline at end of file +1.2.1 \ No newline at end of file diff --git a/plugins/prompts/changelog.md b/plugins/prompts/changelog.md new file mode 100644 index 0000000000..f16b46ac0c --- /dev/null +++ b/plugins/prompts/changelog.md @@ -0,0 +1,3 @@ +- feat: add prompts plugin with direct key header resolver +- feat: add per-user OAuth consent flow with identity selection and MCP authentication +- feat: add selective message inclusion when committing prompt sessions diff --git a/plugins/prompts/go.mod b/plugins/prompts/go.mod new file mode 100644 index 0000000000..d293e006f8 --- /dev/null +++ b/plugins/prompts/go.mod @@ -0,0 +1,79 @@ +module github.com/maximhq/bifrost/plugins/prompts + +go 1.26.1 + +require ( + github.com/maximhq/bifrost/core v1.4.13 + github.com/maximhq/bifrost/framework v1.2.32 + github.com/stretchr/testify v1.11.1 +) + +require ( + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.11 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect + github.com/aws/smithy-go v1.24.2 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.2 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.starlark.net v0.0.0-20260102030733-3fee463870c9 // indirect + golang.org/x/arch v0.23.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/gorm v1.31.1 // indirect +) diff --git a/plugins/prompts/go.sum b/plugins/prompts/go.sum new file mode 100644 index 0000000000..adfeae374c --- /dev/null +++ b/plugins/prompts/go.sum @@ -0,0 +1,209 @@ +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= +github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= +github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= +github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.4.13 h1:ECCIbdgLUy+jYRXOVn3E9uYCu3mYCOh7GV4ElVjHKLU= +github.com/maximhq/bifrost/core v1.4.13/go.mod h1:Kc11vnzU8UgwBTJS+TgG8S9vuSnas+T8uYx3xwzFuIA= +github.com/maximhq/bifrost/framework v1.2.32 h1:J8xhYXM/5bOmNmpWP9avQYoPV63bQ6IoKLAl3ZvxHok= +github.com/maximhq/bifrost/framework v1.2.32/go.mod h1:8IegKP+/HGpbl1Kh7TP/CFuENPjQVUpJiuKh/u3IvXk= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= +go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= +golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/plugins/prompts/helpers_test.go b/plugins/prompts/helpers_test.go new file mode 100644 index 0000000000..ec8f778c7b --- /dev/null +++ b/plugins/prompts/helpers_test.go @@ -0,0 +1,385 @@ +package prompts + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/maximhq/bifrost/core/schemas" + tables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +// ============================================================ +// MockLogger β€” captures log output per level for assertions. +// Follows the same pattern as plugins/governance/test_utils.go. +// ============================================================ + +type MockLogger struct { + mu sync.Mutex + debugs []string + infos []string + warnings []string + errors []string +} + +func NewMockLogger() *MockLogger { + return &MockLogger{ + debugs: make([]string, 0), + infos: make([]string, 0), + warnings: make([]string, 0), + errors: make([]string, 0), + } +} + +func (l *MockLogger) Debug(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.debugs = append(l.debugs, format) +} + +func (l *MockLogger) Info(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.infos = append(l.infos, format) +} + +func (l *MockLogger) Warn(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.warnings = append(l.warnings, format) +} + +func (l *MockLogger) Error(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.errors = append(l.errors, format) +} + +func (l *MockLogger) Fatal(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.errors = append(l.errors, format) +} + +func (l *MockLogger) SetLevel(_ schemas.LogLevel) {} +func (l *MockLogger) SetOutputType(_ schemas.LoggerOutputType) {} +func (l *MockLogger) LogHTTPRequest(_ schemas.LogLevel, _ string) schemas.LogEventBuilder { + return schemas.NoopLogEvent +} + +// Warned returns true if at least one warning was logged. +func (l *MockLogger) Warned() bool { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.warnings) > 0 +} + +// ============================================================ +// mockStore β€” satisfies promptStore with controllable responses. +// ============================================================ + +type mockStore struct { + prompts []tables.TablePrompt + versions []tables.TablePromptVersion + err error +} + +func (m *mockStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) { + return m.prompts, m.err +} + +func (m *mockStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) { + return m.versions, m.err +} + +// versionsErrStore succeeds on GetPrompts but fails on GetAllPromptVersions. +type versionsErrStore struct { + prompts []tables.TablePrompt + err error +} + +func (s *versionsErrStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) { + return s.prompts, nil +} + +func (s *versionsErrStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) { + return nil, s.err +} + +// ============================================================ +// staticResolver β€” returns fixed IDs; decouples PreLLMHook +// tests from HTTP header / context mechanics. +// ============================================================ + +type staticResolver struct { + promptID string + versionNumber int + versionSpecified bool + err error +} + +func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRequest) (string, int, bool, error) { + return r.promptID, r.versionNumber, r.versionSpecified, r.err +} + +// ============================================================ +// Plugin builders +// ============================================================ + +// newPluginWithStore builds a Plugin whose store is set but maps are empty. +// Use only for loadCache tests. +func newPluginWithStore(s promptStore) *Plugin { + return &Plugin{ + store: s, + logger: NewMockLogger(), + resolver: &staticResolver{}, + promptsByID: make(map[string]*tables.TablePrompt), + versionsByPromptAndNumber: make(map[string]map[int]*tables.TablePromptVersion), + } +} + +// newTestPlugin builds a Plugin with pre-seeded in-memory maps, bypassing Init +// and loadCache entirely. The store is nil β€” safe as long as no test path calls +// into the store. +func newTestPlugin(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion) *Plugin { + return newTestPluginWithLogger(resolver, promptMap, versionMap, NewMockLogger()) +} + +// newTestPluginWithLogger is like newTestPlugin but accepts a caller-provided logger +// so tests can inspect logged warnings. +func newTestPluginWithLogger(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion, log schemas.Logger) *Plugin { + if resolver == nil { + resolver = &staticResolver{} + } + if promptMap == nil { + promptMap = make(map[string]*tables.TablePrompt) + } + if versionMap == nil { + versionMap = make(map[string]map[int]*tables.TablePromptVersion) + } + return &Plugin{ + store: nil, + logger: log, + resolver: resolver, + promptsByID: promptMap, + versionsByPromptAndNumber: versionMap, + } +} + +// ============================================================ +// Message builders +// ============================================================ + +// versionMsg creates a TablePromptVersionMessage in the production envelope +// format {"payload": }, matching what the frontend writes +// to the DB and what AfterFind populates into the Message field. +func versionMsg(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: role, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsg: marshal inner failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// versionMsgViaJSON creates a TablePromptVersionMessage that has an empty Message +// field but a populated MessageJSON field, exercising the fallback branch in +// chatMessagesFromVersionMessages. +func versionMsgViaJSON(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: role, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgViaJSON: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: nil, // empty β€” triggers MessageJSON fallback + MessageJSON: envelope, + } +} + +// makeVersion returns a TablePromptVersion with the supplied messages. +// VersionNumber is set to int(id) so tests can reference versions by their number. +func makeVersion(id uint, promptID string, isLatest bool, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion { + return tables.TablePromptVersion{ + ID: id, + PromptID: promptID, + IsLatest: isLatest, + VersionNumber: int(id), + Messages: msgs, + } +} + +// makePrompt returns a TablePrompt, optionally linked to a latest version. +func makePrompt(id string, latest *tables.TablePromptVersion) tables.TablePrompt { + return tables.TablePrompt{ID: id, Name: id, LatestVersion: latest} +} + +// ============================================================ +// Request / context builders +// ============================================================ + +// chatRequest returns a BifrostRequest wrapping a ChatRequest with the given messages. +func chatRequest(msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Input: append([]schemas.ChatMessage{}, msgs...), + }, + } +} + +// userMsg returns a user-role ChatMessage with plain text content. +func userMsg(text string) schemas.ChatMessage { + t := text + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ContentStr: &t}, + } +} + +// systemMsg returns a system-role ChatMessage with plain text content. +func systemMsg(text string) schemas.ChatMessage { + t := text + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: &schemas.ChatMessageContent{ContentStr: &t}, + } +} + +// bfCtx returns a fresh BifrostContext with no deadline. +func bfCtx() *schemas.BifrostContext { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) +} + +// versionMsgWithToolCall creates a TablePromptVersionMessage for an assistant +// message that contains a single tool call (role=assistant, tool_calls=[...]). +func versionMsgWithToolCall(callID, funcName, funcArgs string) tables.TablePromptVersionMessage { + name := funcName + id := callID + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: &id, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &name, + Arguments: funcArgs, + }, + }, + }, + }, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgWithToolCall: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// versionMsgToolResult creates a TablePromptVersionMessage for a tool-result +// message (role=tool) with the given tool_call_id and result text. +func versionMsgToolResult(callID, result string) tables.TablePromptVersionMessage { + id := callID + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ContentStr: &result}, + ChatToolMessage: &schemas.ChatToolMessage{ToolCallID: &id}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgToolResult: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// makeVersionWithParams returns a TablePromptVersion with explicit ModelParams and messages. +// VersionNumber is set to int(id) so tests can reference versions by their number. +func makeVersionWithParams(id uint, promptID string, isLatest bool, params tables.ModelParams, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion { + return tables.TablePromptVersion{ + ID: id, + PromptID: promptID, + IsLatest: isLatest, + VersionNumber: int(id), + ModelParams: params, + Messages: msgs, + } +} + +// chatRequestWithParams returns a BifrostRequest with Params pre-set. +func chatRequestWithParams(params *schemas.ChatParameters, msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Input: append([]schemas.ChatMessage{}, msgs...), + Params: params, + }, + } +} + +// chatRequestWithModel returns a BifrostRequest with the Model field pre-set. +func chatRequestWithModel(model string, msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Model: model, + Input: append([]schemas.ChatMessage{}, msgs...), + }, + } +} + +// versionMsgAssistantUIFormat creates a TablePromptVersionMessage in the format +// the Bifrost UI writes for assistant (completion_result) messages. +// The message is nested at payload.choices[0].message, matching SerializedMessage. +func versionMsgAssistantUIFormat(text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgAssistantUIFormat: marshal failed: %v", err)) + } + payload := fmt.Sprintf(`{"id":"resp-1","choices":[{"index":0,"message":%s,"finish_reason":"stop"}]}`, string(innerJSON)) + envelope := fmt.Sprintf(`{"originalType":"completion_result","payload":%s}`, payload) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// ============================================================ +// errTest β€” minimal error type for test use +// ============================================================ + +type errTest string + +func (e errTest) Error() string { return string(e) } + +// ============================================================ +// Assertion helpers +// ============================================================ + +// msgText extracts the ContentStr from a ChatMessage, returning "" if absent. +func msgText(msg schemas.ChatMessage) string { + if msg.Content == nil || msg.Content.ContentStr == nil { + return "" + } + return *msg.Content.ContentStr +} diff --git a/plugins/prompts/main.go b/plugins/prompts/main.go new file mode 100644 index 0000000000..05c2507b9f --- /dev/null +++ b/plugins/prompts/main.go @@ -0,0 +1,570 @@ +package prompts + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "strconv" + "strings" + "sync" + + "github.com/maximhq/bifrost/core/schemas" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +const ( + PluginName = "prompts" + PromptIDHeader = "bf-prompt-id" + PromptVersionHeader = "bf-prompt-version" + PromptIDKey schemas.BifrostContextKey = PromptIDHeader + PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader +) + +type promptStore interface { + GetPrompts(ctx context.Context, folderID *string) ([]configstoreTables.TablePrompt, error) + GetAllPromptVersions(ctx context.Context) ([]configstoreTables.TablePromptVersion, error) +} + +// PromptResolver decides which prompt and version to inject for a given request. +// Returning an empty promptID means no injection for this request. +type PromptResolver interface { + Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, versionSpecified bool, err error) +} + +// headerResolver is the default OSS resolver: reads prompt ID and version from context +// keys that were populated from HTTP headers in HTTPTransportPreHook. +type headerResolver struct { + logger schemas.Logger +} + +func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, bool, error) { + promptID := promptStringFromCtx(ctx, PromptIDKey) + if promptID == "" { + return "", 0, false, nil + } + versionNumber, specified, err := parsePromptVersionNumber(ctx) + if err != nil { + return "", 0, false, fmt.Errorf("invalid bifrost-prompt-version: %w", err) + } + return promptID, versionNumber, specified, nil +} + +// Plugin resolves stored prompt templates and prepends their messages to LLM requests. +type Plugin struct { + store promptStore + logger schemas.Logger + resolver PromptResolver + + mu sync.RWMutex + promptsByID map[string]*configstoreTables.TablePrompt + versionsByPromptAndNumber map[string]map[int]*configstoreTables.TablePromptVersion +} + +// Init wires the prompts plugin with the default header-based resolver. +func Init(ctx context.Context, store promptStore, logger schemas.Logger) (schemas.LLMPlugin, error) { + return InitWithResolver(ctx, store, &headerResolver{logger: logger}, logger) +} + +// InitWithResolver wires the prompts plugin with a custom resolver. +func InitWithResolver(ctx context.Context, store promptStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) { + if store == nil { + return nil, fmt.Errorf("config store is required for prompts plugin") + } + if resolver == nil { + resolver = &headerResolver{logger: logger} + } + p := &Plugin{ + store: store, + logger: logger, + resolver: resolver, + promptsByID: make(map[string]*configstoreTables.TablePrompt), + versionsByPromptAndNumber: make(map[string]map[int]*configstoreTables.TablePromptVersion), + } + if err := p.loadCache(ctx); err != nil { + return nil, fmt.Errorf("failed to load prompts into memory: %w", err) + } + return p, nil +} + +// loadCache rebuilds the in-memory maps with exactly two DB queries: +// one for all prompts (with their latest version), one for all versions. +func (p *Plugin) loadCache(ctx context.Context) error { + prompts, err := p.store.GetPrompts(ctx, nil) + if err != nil { + return err + } + + versions, err := p.store.GetAllPromptVersions(ctx) + if err != nil { + return fmt.Errorf("loading all prompt versions: %w", err) + } + + newPrompts := make(map[string]*configstoreTables.TablePrompt, len(prompts)) + for i := range prompts { + newPrompts[prompts[i].ID] = &prompts[i] + } + + newVersionsByPromptAndNumber := make(map[string]map[int]*configstoreTables.TablePromptVersion) + for i := range versions { + v := &versions[i] + if _, ok := newVersionsByPromptAndNumber[v.PromptID]; !ok { + newVersionsByPromptAndNumber[v.PromptID] = make(map[int]*configstoreTables.TablePromptVersion) + } + newVersionsByPromptAndNumber[v.PromptID][v.VersionNumber] = v + } + + p.mu.Lock() + p.promptsByID = newPrompts + p.versionsByPromptAndNumber = newVersionsByPromptAndNumber + p.mu.Unlock() + return nil +} + +// Reload refreshes the in-memory cache from the store. Called by the HTTP handler +// after any create/update/delete operation on prompts or versions. +func (p *Plugin) Reload(ctx context.Context) error { + return p.loadCache(ctx) +} + +func (p *Plugin) GetName() string { + return PluginName +} + +func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if req == nil { + return nil, nil + } + if id := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptIDHeader)); id != "" { + ctx.SetValue(PromptIDKey, id) + } + if v := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptVersionHeader)); v != "" { + ctx.SetValue(PromptVersionKey, v) + } + p.setPromptStreamFromVersionForTransport(ctx) + return nil, nil +} + +// setPromptStreamFromVersionForTransport sets BifrostContextKeyPromptStreamRequest when +// the resolved prompt version has stream:true in its ModelParams. +func (p *Plugin) setPromptStreamFromVersionForTransport(ctx *schemas.BifrostContext) { + promptID := promptStringFromCtx(ctx, PromptIDKey) + if promptID == "" { + return + } + versionNumber, versionSpecified, err := parsePromptVersionNumber(ctx) + if err != nil { + return + } + _, version, ok := p.resolveVersion(promptID, versionNumber, versionSpecified) + if !ok || version == nil || len(version.ModelParams) == 0 { + return + } + if includesStreamInModelParams(version.ModelParams) { + ctx.SetValue(schemas.BifrostContextKeyPromptStreamRequest, true) + } +} + +func includesStreamInModelParams(mp configstoreTables.ModelParams) bool { + raw, ok := mp["stream"] + if !ok { + return true // default to true if stream is not set, this is done because for the initial version, the stream key is not present but we default to true for the initial version and show it as well on the UI. If the user toggles stream off, we set `stream: false` in the model params in db. + } + switch v := raw.(type) { + case bool: + return v + case json.Number: + if i, err := strconv.ParseInt(string(v), 10, 64); err == nil { + return i != 0 + } + b, err := strconv.ParseBool(string(v)) + return err == nil && b + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "true", "1", "yes": + return true + default: + return false + } + default: + return false + } +} + +func (p *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { + return nil +} + +func (p *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { + return chunk, nil +} + +func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + promptID, versionNumber, versionSpecified, err := p.resolver.Resolve(ctx, req) + if err != nil { + p.logger.Warn("prompts plugin: failed to resolve prompt: %v", err) + return req, nil, nil + } + if promptID == "" { + return req, nil, nil + } + + _, version, found := p.resolveVersion(promptID, versionNumber, versionSpecified) + if !found { + p.logger.Warn("prompts plugin: prompt or version not found: %s", promptID) + return req, nil, nil + } + + if version == nil { + p.logger.Warn("prompts plugin: prompt %s has no versions", promptID) + return req, nil, nil + } + + // Apply model params from the version (version params are defaults; request params win). + switch { + case req.ChatRequest != nil: + applyVersionParamsToChatRequest(version, req.ChatRequest, p.logger) + case req.ResponsesRequest != nil: + applyVersionParamsToResponsesRequest(version, req.ResponsesRequest, p.logger) + } + + template, err := chatMessagesFromVersionMessages(version.Messages) + if err != nil { + p.logger.Warn("prompts plugin: failed to parse messages for prompt %s: %v", promptID, err) + return req, nil, nil + } + if len(template) == 0 { + return req, nil, nil + } + + switch { + case req.ChatRequest != nil: + mergeChatMessages(&req.ChatRequest.Input, template) + case req.ResponsesRequest != nil: + mergeResponsesMessages(&req.ResponsesRequest.Input, template) + } + + return req, nil, nil +} + +func (p *Plugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +// knownSyntheticChatParamKeys are flat JSON keys that ChatParameters.UnmarshalJSON +// promotes into nested structs. They should not be treated as ExtraParams even though +// they won't appear as top-level keys in a re-marshaled ChatParameters. +var knownSyntheticChatParamKeys = map[string]struct{}{ + "reasoning_effort": {}, + "reasoning_max_tokens": {}, +} + +// buildMergedParamsMap builds a merged map[string]interface{} where version params +// serve as defaults and request params take priority. reqParamsBytes is the JSON of +// the request's standard params (ExtraParams excluded); reqExtraParams is its ExtraParams map. +func buildMergedParamsMap(versionParams configstoreTables.ModelParams, reqParamsBytes []byte, reqExtraParams map[string]interface{}) (map[string]interface{}, error) { + merged := make(map[string]interface{}, len(versionParams)) + maps.Copy(merged, versionParams) + if len(reqParamsBytes) > 0 && string(reqParamsBytes) != "null" { + var reqMap map[string]interface{} + if err := schemas.Unmarshal(reqParamsBytes, &reqMap); err != nil { + return nil, fmt.Errorf("unmarshal request params: %w", err) + } + maps.Copy(merged, reqMap) + } + maps.Copy(merged, reqExtraParams) + return merged, nil +} + +// applyVersionParamsToChatRequest applies the prompt version's ModelParams to the +// chat request. Version params are defaults; params already set in the request win. +func applyVersionParamsToChatRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostChatRequest, logger schemas.Logger) { + if len(version.ModelParams) == 0 { + return + } + + var reqParamsBytes []byte + var reqExtraParams map[string]interface{} + if req.Params != nil { + b, err := schemas.Marshal(req.Params) + if err != nil { + logger.Warn("prompts plugin: failed to marshal chat request params: %v", err) + return + } + reqParamsBytes = b + reqExtraParams = req.Params.ExtraParams + } + + merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) + if err != nil { + logger.Warn("prompts plugin: failed to build merged chat params: %v", err) + return + } + + mergedJSON, err := schemas.Marshal(merged) + if err != nil { + logger.Warn("prompts plugin: failed to marshal merged chat params: %v", err) + return + } + + var result schemas.ChatParameters + if err := schemas.Unmarshal(mergedJSON, &result); err != nil { + logger.Warn("prompts plugin: failed to unmarshal merged chat params: %v", err) + return + } + + // Detect keys from merged that were not recognized as standard ChatParameters fields + // (i.e. they won't appear in the re-marshaled output) and put them in ExtraParams. + var recognizedMap map[string]interface{} + recognizedBytes, err := schemas.Marshal(&result) + if err != nil { + logger.Warn("prompts plugin: failed to marshal result chat params: %v", err) + return + } + if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { + logger.Warn("prompts plugin: failed to unmarshal recognized chat params: %v", err) + return + } + for k, v := range merged { + if _, ok := recognizedMap[k]; ok { + continue + } + if _, synthetic := knownSyntheticChatParamKeys[k]; synthetic { + continue + } + if result.ExtraParams == nil { + result.ExtraParams = make(map[string]interface{}) + } + if _, alreadySet := result.ExtraParams[k]; !alreadySet { + result.ExtraParams[k] = v + } + } + + req.Params = &result +} + +// applyVersionParamsToResponsesRequest applies the prompt version's ModelParams to the +// responses request. Version params are defaults; params already set in the request win. +func applyVersionParamsToResponsesRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostResponsesRequest, logger schemas.Logger) { + if len(version.ModelParams) == 0 { + return + } + + var reqParamsBytes []byte + var reqExtraParams map[string]interface{} + if req.Params != nil { + b, err := schemas.Marshal(req.Params) + if err != nil { + logger.Warn("prompts plugin: failed to marshal responses request params: %v", err) + return + } + reqParamsBytes = b + reqExtraParams = req.Params.ExtraParams + } + + merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) + if err != nil { + logger.Warn("prompts plugin: failed to build merged responses params: %v", err) + return + } + + mergedJSON, err := schemas.Marshal(merged) + if err != nil { + logger.Warn("prompts plugin: failed to marshal merged responses params: %v", err) + return + } + + var result schemas.ResponsesParameters + if err := schemas.Unmarshal(mergedJSON, &result); err != nil { + logger.Warn("prompts plugin: failed to unmarshal merged responses params: %v", err) + return + } + + // Detect unrecognized keys and add them to ExtraParams. + var recognizedMap map[string]interface{} + recognizedBytes, err := schemas.Marshal(&result) + if err != nil { + logger.Warn("prompts plugin: failed to marshal result responses params: %v", err) + return + } + if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { + logger.Warn("prompts plugin: failed to unmarshal recognized responses params: %v", err) + return + } + for k, v := range merged { + if _, ok := recognizedMap[k]; ok { + continue + } + if result.ExtraParams == nil { + result.ExtraParams = make(map[string]interface{}) + } + if _, alreadySet := result.ExtraParams[k]; !alreadySet { + result.ExtraParams[k] = v + } + } + + req.Params = &result +} + +// resolveVersion centralises the map-lookup logic shared by setPromptStreamFromVersionForTransport +// and PreLLMHook. It returns the prompt and its resolved version (either the explicitly requested +// version or the prompt's latest version), plus a bool indicating whether both were found. +func (p *Plugin) resolveVersion(promptID string, versionNumber int, versionSpecified bool) ( + *configstoreTables.TablePrompt, *configstoreTables.TablePromptVersion, bool, +) { + p.mu.RLock() + defer p.mu.RUnlock() + + prompt, ok := p.promptsByID[promptID] + if !ok || prompt == nil { + return nil, nil, false + } + if !versionSpecified { + return prompt, prompt.LatestVersion, true + } + byNumber, ok := p.versionsByPromptAndNumber[promptID] + if !ok { + return nil, nil, false + } + v, found := byNumber[versionNumber] + if !found || v == nil { + return nil, nil, false + } + return prompt, v, true +} + +func (p *Plugin) Cleanup() error { + return nil +} + +func promptStringFromCtx(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) string { + if v, ok := ctx.Value(key).(string); ok { + return strings.TrimSpace(v) + } + return "" +} + +func parsePromptVersionNumber(ctx *schemas.BifrostContext) (num int, specified bool, err error) { + s, ok := ctx.Value(PromptVersionKey).(string) + if !ok { + return 0, false, nil + } + s = strings.TrimSpace(s) + if s == "" { + return 0, false, nil + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, true, err + } + return int(n), true, nil +} + +func chatMessagePopulated(cm schemas.ChatMessage) bool { + if strings.TrimSpace(string(cm.Role)) != "" { + return true + } + if cm.Content != nil { + return true + } + if cm.Name != nil && strings.TrimSpace(*cm.Name) != "" { + return true + } + if cm.ChatToolMessage != nil { + return true + } + if cm.ChatAssistantMessage != nil { + return true + } + return false +} + +// convertVersionMessagesToChatMessages unmarshals prompt-repo JSON into ChatMessage. +func convertVersionMessagesToChatMessages(data []byte) (schemas.ChatMessage, error) { + s := strings.TrimSpace(string(data)) + if s == "" || s == "null" { + return schemas.ChatMessage{}, fmt.Errorf("empty message") + } + data = []byte(s) + + var msg struct { + OriginalType string `json:"originalType"` + Payload json.RawMessage `json:"payload"` + } + if err := schemas.Unmarshal(data, &msg); err == nil { + ps := strings.TrimSpace(string(msg.Payload)) + if ps != "" && ps != "null" { + if msg.OriginalType == "completion_result" { + var result struct { + Choices []struct { + Message *schemas.ChatMessage `json:"message"` + } `json:"choices"` + } + if err := schemas.Unmarshal([]byte(ps), &result); err == nil && + len(result.Choices) > 0 && result.Choices[0].Message != nil { + if chatMessagePopulated(*result.Choices[0].Message) { + return *result.Choices[0].Message, nil + } + } + } + + // completion_request / tool_result / legacy envelope: payload is a direct ChatMessage. + var message schemas.ChatMessage + if err := schemas.Unmarshal([]byte(ps), &message); err != nil { + return schemas.ChatMessage{}, fmt.Errorf("decoding prompt message envelope payload: %w", err) + } + if chatMessagePopulated(message) { + return message, nil + } + } + } + + var chatMessage schemas.ChatMessage + if err := schemas.Unmarshal(data, &chatMessage); err != nil { + return schemas.ChatMessage{}, err + } + return chatMessage, nil +} + +func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVersionMessage) ([]schemas.ChatMessage, error) { + out := make([]schemas.ChatMessage, 0, len(messages)) + for i := range messages { + row := &messages[i] + data := row.Message + if len(data) == 0 && row.MessageJSON != "" { + data = []byte(row.MessageJSON) + } + cm, err := convertVersionMessagesToChatMessages(data) + if err != nil { + return nil, fmt.Errorf("stored prompt message is not valid chat JSON: %w", err) + } + out = append(out, cm) + } + return out, nil +} + +func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage) { + if dest == nil || len(prefix) == 0 { + return + } + cur := *dest + merged := make([]schemas.ChatMessage, 0, len(prefix)+len(cur)) + merged = append(merged, prefix...) + merged = append(merged, cur...) + *dest = merged +} + +func mergeResponsesMessages(dest *[]schemas.ResponsesMessage, template []schemas.ChatMessage) { + if dest == nil || len(template) == 0 { + return + } + var prefix []schemas.ResponsesMessage + for i := range template { + prefix = append(prefix, template[i].ToResponsesMessages()...) + } + cur := *dest + merged := make([]schemas.ResponsesMessage, 0, len(prefix)+len(cur)) + merged = append(merged, prefix...) + merged = append(merged, cur...) + *dest = merged +} diff --git a/plugins/prompts/plugin_test.go b/plugins/prompts/plugin_test.go new file mode 100644 index 0000000000..554174705c --- /dev/null +++ b/plugins/prompts/plugin_test.go @@ -0,0 +1,1065 @@ +package prompts + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + tables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// InitWithResolver +// ============================================================ + +func TestInitWithResolver_NilStore(t *testing.T) { + _, err := InitWithResolver(context.Background(), nil, &staticResolver{}, NewMockLogger()) + require.Error(t, err, "expected error for nil store") +} + +func TestInitWithResolver_NilResolverFallsBackToHeader(t *testing.T) { + ms := &mockStore{} + p, err := InitWithResolver(context.Background(), ms, nil, NewMockLogger()) + require.NoError(t, err) + require.NotNil(t, p) + _, ok := p.resolver.(*headerResolver) + assert.True(t, ok, "expected headerResolver, got %T", p.resolver) +} + +// ============================================================ +// loadCache +// ============================================================ + +func TestLoadCache_EmptyStore(t *testing.T) { + p := newPluginWithStore(&mockStore{}) + require.NoError(t, p.loadCache(context.Background())) + assert.Empty(t, p.promptsByID) + assert.Empty(t, p.versionsByPromptAndNumber) +} + +func TestLoadCache_PopulatesMaps(t *testing.T) { + v1 := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "Hello")) + v2 := makeVersion(2, "p2", true) + p1 := makePrompt("p1", &v1) + p2 := makePrompt("p2", &v2) + + p := newPluginWithStore(&mockStore{ + prompts: []tables.TablePrompt{p1, p2}, + versions: []tables.TablePromptVersion{v1, v2}, + }) + + require.NoError(t, p.loadCache(context.Background())) + assert.Len(t, p.promptsByID, 2) + assert.Len(t, p.versionsByPromptAndNumber, 2) + assert.NotNil(t, p.promptsByID["p1"]) + assert.NotNil(t, p.versionsByPromptAndNumber["p1"][1]) +} + +func TestLoadCache_GetPromptsError(t *testing.T) { + p := newPluginWithStore(&mockStore{err: errTest("boom")}) + err := p.loadCache(context.Background()) + require.Error(t, err) +} + +func TestLoadCache_GetVersionsError(t *testing.T) { + p := newPluginWithStore(&versionsErrStore{ + prompts: []tables.TablePrompt{makePrompt("p1", nil)}, + err: errTest("versions boom"), + }) + err := p.loadCache(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "versions boom") +} + +// ============================================================ +// PreLLMHook +// ============================================================ + +func TestPreLLMHook_NoPromptID(t *testing.T) { + p := newTestPlugin(&staticResolver{promptID: ""}, nil, nil) + out, sc, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + assert.Nil(t, sc) + assert.Len(t, out.ChatRequest.Input, 1) +} + +func TestPreLLMHook_PromptNotFound(t *testing.T) { + log := NewMockLogger() + p := newTestPluginWithLogger(&staticResolver{promptID: "missing"}, nil, nil, log) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected a warning for unknown prompt") +} + +func TestPreLLMHook_UseLatestVersion(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "Be helpful"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2, "expected system prompt + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "Be helpful", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[1])) +} + +func TestPreLLMHook_UseSpecificVersion(t *testing.T) { + vLatest := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "latest system prompt"), + ) + vOld := makeVersion(2, "p1", false, + versionMsg(schemas.ChatMessageRoleSystem, "old system prompt"), + ) + prompt := makePrompt("p1", &vLatest) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + // Must use vOld, not vLatest. + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "old system prompt", msgText(out.ChatRequest.Input[0])) +} + +func TestPreLLMHook_VersionNotFound(t *testing.T) { + v := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "hello")) + prompt := makePrompt("p1", &v) + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionNumber: 99, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for missing version") +} + +func TestPreLLMHook_VersionBelongsToDifferentPrompt(t *testing.T) { + v := makeVersion(1, "p2", true, versionMsg(schemas.ChatMessageRoleSystem, "wrong")) + prompt := makePrompt("p1", nil) + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionNumber: 1, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p2": {1: &v}}, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for version/prompt mismatch") +} + +func TestPreLLMHook_NoLatestVersion(t *testing.T) { + prompt := makePrompt("p1", nil) // LatestVersion is nil + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + nil, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for missing latest version") +} + +func TestPreLLMHook_EmptyTemplate(t *testing.T) { + v := makeVersion(1, "p1", true) // no messages + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1) +} + +func TestPreLLMHook_MultipleTemplateMessages(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "sys prompt"), + versionMsg(schemas.ChatMessageRoleUser, "example input"), + versionMsg(schemas.ChatMessageRoleAssistant, "example output"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("actual question"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 4, "expected 3 template messages + 1 original") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "sys prompt", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "example input", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[2].Role) + assert.Equal(t, "example output", msgText(out.ChatRequest.Input[2])) + + // Original user message must be last, content preserved. + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role) + assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[3])) +} + +func TestPreLLMHook_ResolverError(t *testing.T) { + log := NewMockLogger() + p := newTestPluginWithLogger( + &staticResolver{err: errTest("resolver failed")}, + nil, nil, log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err, "PreLLMHook must not propagate resolver errors") + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for resolver error") +} + +func TestPreLLMHook_MessageJSON_FallbackPath(t *testing.T) { + // Messages where Message ([]byte) is nil but MessageJSON is set β€” the fallback + // branch in chatMessagesFromVersionMessages. This mirrors rows loaded from + // an older DB schema before AfterFind was established. + v := makeVersion(1, "p1", true, + versionMsgViaJSON(schemas.ChatMessageRoleSystem, "from json field"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "from json field", msgText(out.ChatRequest.Input[0])) +} + +func TestPreLLMHook_ResponsesRequest(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "be concise"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + userRole := schemas.ResponsesMessageRoleType("user") + req := &schemas.BifrostRequest{ + ResponsesRequest: &schemas.BifrostResponsesRequest{ + Input: []schemas.ResponsesMessage{{Role: &userRole}}, + }, + } + + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + // Template message(s) prepended before the original user input. + assert.Greater(t, len(out.ResponsesRequest.Input), 1, "expected template prepended before user message") + // Original user message must still be last. + last := out.ResponsesRequest.Input[len(out.ResponsesRequest.Input)-1] + assert.Equal(t, schemas.ResponsesMessageRoleType("user"), *last.Role) +} + +// TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg verifies that when the +// prompt template contains a system message and the incoming request also starts +// with a system message, both system messages are forwarded to the model β€” +// the plugin's only job is prepending, not de-duplicating or filtering roles. +func TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "prompt system"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + // Incoming request already has its own system message before the user turn. + out, _, err := p.PreLLMHook(bfCtx(), chatRequest( + systemMsg("user-side system context"), + userMsg("actual question"), + )) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected prompt system + user system + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "prompt system", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role) + assert.Equal(t, "user-side system context", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[2])) +} + +// TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage verifies that when +// the prompt template contains a full tool-call turn (system β†’ assistant with +// tool_calls β†’ tool result) and the user sends a new message, the entire +// template is prepended and all fields (ToolCalls, ToolCallID) are preserved. +func TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage(t *testing.T) { + const callID = "call_abc123" + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "you are a weather bot"), + versionMsgWithToolCall(callID, "get_weather", `{"city":"Paris"}`), + versionMsgToolResult(callID, "Sunny, 22Β°C"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("what about tomorrow?"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 4, "expected system + assistant(tool_calls) + tool_result + user") + + // System message from prompt. + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "you are a weather bot", msgText(out.ChatRequest.Input[0])) + + // Assistant message with tool_calls must carry its ToolCalls slice. + assistantMsg := out.ChatRequest.Input[1] + assert.Equal(t, schemas.ChatMessageRoleAssistant, assistantMsg.Role) + require.NotNil(t, assistantMsg.ChatAssistantMessage, "ChatAssistantMessage must be present") + require.Len(t, assistantMsg.ChatAssistantMessage.ToolCalls, 1) + tc := assistantMsg.ChatAssistantMessage.ToolCalls[0] + require.NotNil(t, tc.ID) + assert.Equal(t, callID, *tc.ID) + require.NotNil(t, tc.Function.Name) + assert.Equal(t, "get_weather", *tc.Function.Name) + assert.Equal(t, `{"city":"Paris"}`, tc.Function.Arguments) + + // Tool result message must carry the ToolCallID. + toolResultMsg := out.ChatRequest.Input[2] + assert.Equal(t, schemas.ChatMessageRoleTool, toolResultMsg.Role) + assert.Equal(t, "Sunny, 22Β°C", msgText(toolResultMsg)) + require.NotNil(t, toolResultMsg.ChatToolMessage, "ChatToolMessage must be present") + require.NotNil(t, toolResultMsg.ChatToolMessage.ToolCallID) + assert.Equal(t, callID, *toolResultMsg.ChatToolMessage.ToolCallID) + + // Original user message is last. + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role) + assert.Equal(t, "what about tomorrow?", msgText(out.ChatRequest.Input[3])) +} + +// TestPreLLMHook_MultipleSystemMessages_InPromptTemplate verifies that a prompt +// template may itself contain multiple system messages and all of them are +// prepended before the user's input in the original order. +func TestPreLLMHook_MultipleSystemMessages_InPromptTemplate(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "first system"), + versionMsg(schemas.ChatMessageRoleSystem, "second system"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected 2 system messages + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "first system", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role) + assert.Equal(t, "second system", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2])) +} + +// ============================================================ +// HTTPTransportPreHook +// ============================================================ + +func TestHTTPTransportPreHook_NilRequest(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + resp, err := p.HTTPTransportPreHook(bfCtx(), nil) + assert.NoError(t, err) + assert.Nil(t, resp) +} + +func TestHTTPTransportPreHook_SetsPromptID(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: "my-prompt"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "my-prompt", got) +} + +func TestHTTPTransportPreHook_SetsVersionID(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptVersionHeader: "42"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptVersionKey).(string) + assert.Equal(t, "42", got) +} + +func TestHTTPTransportPreHook_TrimsWhitespace(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: " padded "}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "padded", got) +} + +func TestHTTPTransportPreHook_WhitespaceOnlyNotSet(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: " "}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + assert.Nil(t, ctx.Value(PromptIDKey), "whitespace-only header must not be stored in context") +} + +func TestHTTPTransportPreHook_CaseInsensitiveHeaders(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + // "Bf-Prompt-Id" is a title-case variant of the canonical "bf-prompt-id". + req := &schemas.HTTPRequest{ + Headers: map[string]string{"Bf-Prompt-Id": "upper-case"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "upper-case", got) +} + +// ============================================================ +// chatMessageFromStoredJSON +// ============================================================ + +func TestChatMessageFromStoredJSON(t *testing.T) { + systemText := "you are helpful" + directMsg := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: &schemas.ChatMessageContent{ContentStr: &systemText}, + } + directJSON, _ := json.Marshal(directMsg) + envelopeJSON := []byte(`{"payload":` + string(directJSON) + `}`) + + tests := []struct { + name string + input []byte + wantErr bool + wantRole schemas.ChatMessageRole + wantText string + }{ + { + name: "direct format", + input: directJSON, + wantRole: schemas.ChatMessageRoleSystem, + wantText: systemText, + }, + { + name: "envelope format", + input: envelopeJSON, + wantRole: schemas.ChatMessageRoleSystem, + wantText: systemText, + }, + { + // UI format for assistant messages: originalType=completion_result, + // payload is a BifrostChatResponse; message lives at choices[0].message. + name: "completion_result envelope (UI assistant format)", + input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}}`), + wantRole: schemas.ChatMessageRoleAssistant, + wantText: "hi there", + }, + { + // completion_result with no choices falls through to direct ChatMessage parse. + name: "completion_result envelope with empty choices", + input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[]}}`), + wantErr: false, + wantRole: "", + wantText: "", + }, + { + name: "empty bytes", + input: []byte(""), + wantErr: true, + }, + { + name: "null bytes", + input: []byte("null"), + wantErr: true, + }, + { + name: "whitespace only", + input: []byte(" "), + wantErr: true, + }, + { + name: "malformed envelope payload", + input: []byte(`{"payload":"not-a-chat-message"}`), + wantErr: true, + }, + { + // {"payload":null} β€” envelope path is skipped (payload is "null"), + // falls through to direct decode of the outer object as ChatMessage. + // schemas.Unmarshal succeeds on an unknown-field object β†’ empty ChatMessage, no error. + name: "envelope with null payload falls through to direct decode", + input: []byte(`{"payload":null}`), + wantErr: false, + wantRole: "", + wantText: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := convertVersionMessagesToChatMessages(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantRole, got.Role) + assert.Equal(t, tt.wantText, msgText(got)) + }) + } +} + +func TestChatMessageFromStoredJSON_EnvelopeWithEmptyStringPayload(t *testing.T) { + // {"payload":""} β€” the payload field is a non-null, non-empty JSON string `""`. + // The envelope path attempts to unmarshal `""` (a JSON string literal) into + // schemas.ChatMessage (a struct), which fails. The error is returned directly; + // there is no further fallback. + input := []byte(`{"payload":""}`) + _, err := convertVersionMessagesToChatMessages(input) + require.Error(t, err) + assert.Contains(t, err.Error(), "decoding prompt message envelope payload") +} + +// ============================================================ +// parsePromptVersionNumber +// ============================================================ + +func TestParsePromptVersionNumber(t *testing.T) { + type want struct { + num int + specified bool + wantErr bool + } + + tests := []struct { + name string + value any // stored in context; nil means don't set + want want + }{ + {name: "nil β€” not specified", value: nil, want: want{0, false, false}}, + {name: "string valid", value: "99", want: want{99, true, false}}, + {name: "string empty", value: "", want: want{0, false, false}}, + {name: "string whitespace", value: " ", want: want{0, false, false}}, + {name: "string invalid", value: "abc", want: want{0, true, true}}, + {name: "unknown type", value: struct{}{}, want: want{0, false, false}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := bfCtx() + if tt.value != nil { + ctx.SetValue(PromptVersionKey, tt.value) + } + + num, specified, err := parsePromptVersionNumber(ctx) + + if tt.want.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want.specified, specified) + assert.Equal(t, tt.want.num, num) + }) + } +} + +// ============================================================ +// mergeChatMessages +// ============================================================ + +func TestMergeChatMessages(t *testing.T) { + t.Run("prepends prefix before existing messages", func(t *testing.T) { + dest := []schemas.ChatMessage{userMsg("original")} + prefix := []schemas.ChatMessage{systemMsg("system")} + mergeChatMessages(&dest, prefix) + + require.Len(t, dest, 2) + assert.Equal(t, schemas.ChatMessageRoleSystem, dest[0].Role) + assert.Equal(t, "system", msgText(dest[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, dest[1].Role) + assert.Equal(t, "original", msgText(dest[1])) + }) + + t.Run("nil dest is a no-op", func(t *testing.T) { + // Must not panic. + mergeChatMessages(nil, []schemas.ChatMessage{systemMsg("x")}) + }) + + t.Run("empty prefix is a no-op", func(t *testing.T) { + dest := []schemas.ChatMessage{userMsg("only")} + mergeChatMessages(&dest, nil) + assert.Len(t, dest, 1) + assert.Equal(t, "only", msgText(dest[0])) + }) +} + +// ============================================================ +// chatMessagesFromVersionMessages +// ============================================================ + +func TestChatMessagesFromVersionMessages_SingleMessage(t *testing.T) { + msg := versionMsg(schemas.ChatMessageRoleUser, "hello") + out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg}) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, schemas.ChatMessageRoleUser, out[0].Role) + assert.Equal(t, "hello", msgText(out[0])) +} + +func TestChatMessagesFromVersionMessages_MessageJSONFallback(t *testing.T) { + // Row has no Message bytes but has MessageJSON β€” exercises the fallback branch. + msg := versionMsgViaJSON(schemas.ChatMessageRoleAssistant, "assistant reply") + out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg}) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, schemas.ChatMessageRoleAssistant, out[0].Role) + assert.Equal(t, "assistant reply", msgText(out[0])) +} + +func TestChatMessagesFromVersionMessages_PreservesOrder(t *testing.T) { + msgs := []tables.TablePromptVersionMessage{ + versionMsg(schemas.ChatMessageRoleSystem, "first"), + versionMsg(schemas.ChatMessageRoleUser, "second"), + versionMsg(schemas.ChatMessageRoleAssistant, "third"), + } + out, err := chatMessagesFromVersionMessages(msgs) + require.NoError(t, err) + require.Len(t, out, 3) + assert.Equal(t, schemas.ChatMessageRoleSystem, out[0].Role) + assert.Equal(t, "first", msgText(out[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, out[1].Role) + assert.Equal(t, "second", msgText(out[1])) + assert.Equal(t, schemas.ChatMessageRoleAssistant, out[2].Role) + assert.Equal(t, "third", msgText(out[2])) +} + +func TestChatMessagesFromVersionMessages_InvalidJSON(t *testing.T) { + bad := tables.TablePromptVersionMessage{Message: []byte(`not-json`)} + _, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{bad}) + require.Error(t, err) +} + +// ============================================================ +// loadCache + PreLLMHook integration (store β†’ cache β†’ injection) +// ============================================================ + +// ============================================================ +// includesStreamInModelParams +// ============================================================ + +func TestIncludesStreamInModelParams(t *testing.T) { + tests := []struct { + name string + params tables.ModelParams + want bool + }{ + {"bool true", tables.ModelParams{"stream": true}, true}, + {"bool false", tables.ModelParams{"stream": false}, false}, + {"string true", tables.ModelParams{"stream": "true"}, true}, + {"string yes", tables.ModelParams{"stream": "yes"}, true}, + {"string 1", tables.ModelParams{"stream": "1"}, true}, + {"string false", tables.ModelParams{"stream": "false"}, false}, + {"string 0", tables.ModelParams{"stream": "0"}, false}, + {"absent key", tables.ModelParams{"temperature": 0.7}, true}, + {"empty params", tables.ModelParams{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, includesStreamInModelParams(tt.params)) + }) + } +} + +// ============================================================ +// HTTPTransportPreHook β€” stream routing via version ModelParams +// ============================================================ + +// TestHTTPTransportPreHook_StreamTrue_SetsStreamContext verifies that when the +// resolved version has stream:true in ModelParams, the hook marks the bifrost +// context so that the inference handler opens an SSE response. +func TestHTTPTransportPreHook_StreamTrue_SetsStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": true}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) + assert.True(t, streamVal, "expected BifrostContextKeyPromptStreamRequest=true when version has stream:true") +} + +// TestHTTPTransportPreHook_StreamFalse_NoStreamContext verifies that stream:false +// in ModelParams does NOT set the stream context key. +func TestHTTPTransportPreHook_StreamFalse_NoStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + assert.Nil(t, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), + "expected BifrostContextKeyPromptStreamRequest not set when version has stream:false") +} + +// TestHTTPTransportPreHook_NoStreamParam_NoStreamContext verifies that when no +// "stream" key is present in ModelParams, the stream context key is not set. +func TestHTTPTransportPreHook_NoStreamParam_NoStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"temperature": float64(0.7)}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + assert.Equal(t, true, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), + "expected BifrostContextKeyPromptStreamRequest to default to true when no stream key in params") +} + +// TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext verifies +// that when a specific (non-latest) version is requested via header and that +// version has stream:true, the stream context key is set β€” even if the latest +// version has stream:false. +func TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext(t *testing.T) { + vLatest := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) + vOld := makeVersionWithParams(2, "p1", false, tables.ModelParams{"stream": true}) + prompt := makePrompt("p1", &vLatest) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{ + PromptIDHeader: "p1", + PromptVersionHeader: "2", + }, + } + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) + assert.True(t, streamVal, "expected stream=true from explicitly requested version with stream:true") +} + +// ============================================================ +// PreLLMHook β€” model params merge and override +// ============================================================ + +// TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams verifies that when +// the request carries no Params at all, the version's ModelParams become the +// effective parameters on the outgoing request. +func TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{"temperature": float64(0.7)}, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params, "expected Params to be set from version ModelParams") + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, 0.7, *out.ChatRequest.Params.Temperature, 0.001) +} + +// TestPreLLMHook_RequestParamsOverrideVersionParams verifies that when both the +// version and the request specify the same parameter, the request value wins. +func TestPreLLMHook_RequestParamsOverrideVersionParams(t *testing.T) { + reqTemp := 0.9 + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{"temperature": float64(0.3)}, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params) + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001, + "request temperature must override version default temperature") +} + +// TestPreLLMHook_RequestParamsPartialOverride verifies the mixed case: version +// sets temperature and max_completion_tokens; request overrides only temperature. +// The version's max_completion_tokens must still be applied. +func TestPreLLMHook_RequestParamsPartialOverride(t *testing.T) { + reqTemp := 0.9 + maxTokens := 200 + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{ + "temperature": float64(0.3), + "max_completion_tokens": float64(maxTokens), + }, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params) + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001, + "request temperature must override version temperature") + require.NotNil(t, out.ChatRequest.Params.MaxCompletionTokens, + "version max_completion_tokens must be applied when request does not override it") + assert.Equal(t, maxTokens, *out.ChatRequest.Params.MaxCompletionTokens) +} + +// ============================================================ +// PreLLMHook β€” model field preservation +// ============================================================ + +// TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel verifies that +// a "model" key inside a version's ModelParams (which the UI may store alongside +// temperature etc.) does NOT replace the model field on the outgoing +// BifrostChatRequest. The model chosen by the caller must always win. +func TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{ + "model": "openai/gpt-4o", + "temperature": float64(0.5), + }, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithModel("openai/gpt-3.5-turbo", userMsg("hi")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-3.5-turbo", out.ChatRequest.Model, + "request model must not be overridden by model stored in version ModelParams") +} + +// ============================================================ +// loadCache + PreLLMHook integration (store β†’ cache β†’ injection) +// ============================================================ + +// TestLoadCacheAndPreLLMHook_EndToEnd verifies the full pipeline: +// mockStore returns TablePrompt/TablePromptVersion structs β†’ loadCache populates +// the in-memory maps β†’ PreLLMHook injects the template messages correctly. +// This catches any mismatch between how loadCache builds the maps and how +// PreLLMHook reads them (e.g. pointer aliasing, LatestVersion linking). +func TestLoadCacheAndPreLLMHook_EndToEnd(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "end-to-end system"), + ) + prompt := makePrompt("p1", &v) + + ms := &mockStore{ + prompts: []tables.TablePrompt{prompt}, + versions: []tables.TablePromptVersion{v}, + } + + p := newPluginWithStore(ms) + require.NoError(t, p.loadCache(context.Background())) + + p.resolver = &staticResolver{promptID: "p1", versionSpecified: false} + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("user msg"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "end-to-end system", msgText(out.ChatRequest.Input[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "user msg", msgText(out.ChatRequest.Input[1])) +} + +// TestLoadCacheAndPreLLMHook_SpecificVersion exercises the loadCache β†’ PreLLMHook +// path for a version lookup by ID (not just the LatestVersion pointer). +func TestLoadCacheAndPreLLMHook_SpecificVersion(t *testing.T) { + vOld := makeVersion(2, "p1", false, + versionMsg(schemas.ChatMessageRoleSystem, "old via store"), + ) + vLatest := makeVersion(3, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "latest via store"), + ) + prompt := makePrompt("p1", &vLatest) + + ms := &mockStore{ + prompts: []tables.TablePrompt{prompt}, + versions: []tables.TablePromptVersion{vOld, vLatest}, + } + + p := newPluginWithStore(ms) + require.NoError(t, p.loadCache(context.Background())) + + p.resolver = &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true} + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("question"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + assert.Equal(t, "old via store", msgText(out.ChatRequest.Input[0])) +} + +// TestPreLLMHook_AssistantMessage_UIFormat verifies that assistant messages stored +// in the Bifrost UI's completion_result format (payload.choices[0].message) are +// correctly extracted and prepended to the request. +func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "be helpful"), + versionMsgAssistantUIFormat("sure, how can I help?"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected system + assistant + user") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "be helpful", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[1].Role) + assert.Equal(t, "sure, how can I help?", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2])) +} diff --git a/plugins/prompts/version b/plugins/prompts/version new file mode 100644 index 0000000000..7f207341d5 --- /dev/null +++ b/plugins/prompts/version @@ -0,0 +1 @@ +1.0.1 \ No newline at end of file diff --git a/plugins/semanticcache/changelog.md b/plugins/semanticcache/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/semanticcache/changelog.md +++ b/plugins/semanticcache/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/semanticcache/go.mod b/plugins/semanticcache/go.mod index 23fa139b8c..f81ec3e1d6 100644 --- a/plugins/semanticcache/go.mod +++ b/plugins/semanticcache/go.mod @@ -5,9 +5,9 @@ go 1.26.1 require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/google/uuid v1.6.0 - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 - github.com/maximhq/bifrost/plugins/mocker v1.4.17 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 + github.com/maximhq/bifrost/plugins/mocker v1.5.1 ) require ( diff --git a/plugins/semanticcache/go.sum b/plugins/semanticcache/go.sum index 2d8263170f..53529fe4a1 100644 --- a/plugins/semanticcache/go.sum +++ b/plugins/semanticcache/go.sum @@ -195,12 +195,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= -github.com/maximhq/bifrost/plugins/mocker v1.4.17 h1:CEItx77k22fS/N5K8/dCQpse88yfbgzVebQWJXOH4NY= -github.com/maximhq/bifrost/plugins/mocker v1.4.17/go.mod h1:RrA/XyRkggxYiK10k6D6r9VjfmRyiGBIW92ZvhWAtUw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= +github.com/maximhq/bifrost/plugins/mocker v1.5.1 h1:tXB8WPH9J7MURk45PNjx0hh9TeZzyBXqAYFaKUWdQtM= +github.com/maximhq/bifrost/plugins/mocker v1.5.1/go.mod h1:qbjCfskG6jN23rtrLYmaxFBvA5CzOTJ67UIEuyFkO90= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= github.com/oapi-codegen/runtime v1.1.1/go.mod h1:SK9X900oXmPWilYR5/WKPzt3Kqxn/uS/+lbpREv+eCg= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= diff --git a/plugins/semanticcache/plugin_cache_type_test.go b/plugins/semanticcache/plugin_cache_type_test.go index b97a09715f..13979df38f 100644 --- a/plugins/semanticcache/plugin_cache_type_test.go +++ b/plugins/semanticcache/plugin_cache_type_test.go @@ -390,9 +390,10 @@ func TestDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t *testing }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-5.2", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-5.2", + ResolvedModelUsed: "gpt-5.2", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -417,8 +418,11 @@ func TestDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t *testing if extraFields.Provider != schemas.OpenAI { t.Fatalf("expected cached provider %q, got %q", schemas.OpenAI, extraFields.Provider) } - if extraFields.ModelRequested != "gpt-5.2" { - t.Fatalf("expected cached model_requested %q, got %q", "gpt-5.2", extraFields.ModelRequested) + if extraFields.OriginalModelRequested != "gpt-5.2" { + t.Fatalf("expected OriginalModelRequested %q, got %q", "gpt-5.2", extraFields.OriginalModelRequested) + } + if extraFields.ResolvedModelUsed != "gpt-5.2" { + t.Fatalf("expected ResolvedModelUsed %q, got %q", "gpt-5.2", extraFields.ResolvedModelUsed) } if extraFields.CacheDebug == nil { t.Fatal("expected cache_debug on cache hit") @@ -491,10 +495,11 @@ func TestStreamingDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-5.2", - RequestType: schemas.ChatCompletionStreamRequest, - ChunkIndex: chunk.chunkIndex, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-5.2", + ResolvedModelUsed: "gpt-5.2", + RequestType: schemas.ChatCompletionStreamRequest, + ChunkIndex: chunk.chunkIndex, }, }, } @@ -526,8 +531,11 @@ func TestStreamingDirectCacheHitPreservesCachedProviderMetadataAcrossProviders(t if extraFields.Provider != schemas.OpenAI { t.Fatalf("expected cached provider %q on chunk %d, got %q", schemas.OpenAI, chunkCount, extraFields.Provider) } - if extraFields.ModelRequested != "gpt-5.2" { - t.Fatalf("expected cached model_requested %q on chunk %d, got %q", "gpt-5.2", chunkCount, extraFields.ModelRequested) + if extraFields.OriginalModelRequested != "gpt-5.2" { + t.Fatalf("expected OriginalModelRequested %q on chunk %d, got %q", "gpt-5.2", chunkCount, extraFields.OriginalModelRequested) + } + if extraFields.ResolvedModelUsed != "gpt-5.2" { + t.Fatalf("expected ResolvedModelUsed %q on chunk %d, got %q", "gpt-5.2", chunkCount, extraFields.ResolvedModelUsed) } if chunkCount == len(chunks)-1 { if extraFields.CacheDebug == nil || !extraFields.CacheDebug.CacheHit { diff --git a/plugins/semanticcache/plugin_core_test.go b/plugins/semanticcache/plugin_core_test.go index 6e29d002c1..822fc1f645 100644 --- a/plugins/semanticcache/plugin_core_test.go +++ b/plugins/semanticcache/plugin_core_test.go @@ -390,7 +390,7 @@ func TestCacheConfiguration(t *testing.T) { Dimension: 1536, Threshold: 0.95, // Very high threshold Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, }, }, expectedBehavior: "strict_matching", @@ -403,7 +403,7 @@ func TestCacheConfiguration(t *testing.T) { Dimension: 1536, Threshold: 0.1, // Very low threshold Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, }, }, expectedBehavior: "loose_matching", @@ -417,7 +417,7 @@ func TestCacheConfiguration(t *testing.T) { Threshold: 0.8, TTL: 1 * time.Hour, // Custom TTL Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: schemas.WhiteList{"*"}, Weight: 1.0}, }, }, expectedBehavior: "custom_ttl", @@ -550,7 +550,7 @@ func TestInvalidProviderRejection(t *testing.T) { Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.TEST_API_KEY"), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, @@ -587,7 +587,7 @@ func TestValidProviderAccepted(t *testing.T) { Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, diff --git a/plugins/semanticcache/plugin_image_generation_test.go b/plugins/semanticcache/plugin_image_generation_test.go index ce70dd1698..f50f3c5c9b 100644 --- a/plugins/semanticcache/plugin_image_generation_test.go +++ b/plugins/semanticcache/plugin_image_generation_test.go @@ -129,7 +129,7 @@ func TestImageGenerationSemanticSearch(t *testing.T) { Dimension: 1536, Threshold: 0.5, Keys: []schemas.Key{ - {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{}, Weight: 1.0}, + {Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), Models: []string{"*"}, Weight: 1.0}, }, } setup := NewTestSetupWithConfig(t, config) diff --git a/plugins/semanticcache/plugin_integration_test.go b/plugins/semanticcache/plugin_integration_test.go index 92e3cd16e2..58ab9d04c3 100644 --- a/plugins/semanticcache/plugin_integration_test.go +++ b/plugins/semanticcache/plugin_integration_test.go @@ -18,7 +18,7 @@ func TestSemanticCacheBasicFlow(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") - + // Test request request := &schemas.BifrostRequest{ RequestType: schemas.ChatCompletionRequest, @@ -75,9 +75,9 @@ func TestSemanticCacheBasicFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -213,9 +213,9 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -309,7 +309,7 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { setup := NewTestSetup(t) defer setup.Cleanup() - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") request := &schemas.BifrostRequest{ @@ -375,10 +375,10 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionStreamRequest, - ChunkIndex: i, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionStreamRequest, + ChunkIndex: i, }, }, } @@ -524,9 +524,9 @@ func TestSemanticCache_CustomTTLHandling(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -547,7 +547,7 @@ func TestSemanticCache_CustomThresholdHandling(t *testing.T) { defer setup.Cleanup() // Configure plugin with custom threshold key - ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) ctx.SetValue(CacheKey, "test-cache-enabled") ctx.SetValue(CacheThresholdKey, 0.95) // Very high threshold @@ -635,9 +635,9 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } diff --git a/plugins/semanticcache/plugin_vectorstore_test.go b/plugins/semanticcache/plugin_vectorstore_test.go index 5ac3029523..5e390bbe80 100644 --- a/plugins/semanticcache/plugin_vectorstore_test.go +++ b/plugins/semanticcache/plugin_vectorstore_test.go @@ -58,7 +58,7 @@ func getDefaultTestConfig() *Config { Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, @@ -132,9 +132,9 @@ func TestSemanticCache_AllVectorStores_BasicFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } @@ -331,9 +331,9 @@ func TestSemanticCache_AllVectorStores_ParameterFiltering(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ModelRequested: "gpt-4o-mini", - RequestType: schemas.ChatCompletionRequest, + Provider: schemas.OpenAI, + OriginalModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, }, } diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index 5bda10e858..1aad7d7ada 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -122,7 +122,7 @@ func (baseAccount *BaseAccount) GetKeysForProvider(ctx context.Context, provider return []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, // Empty models array means it supports ALL models + Models: schemas.WhiteList{"*"}, // "*" means allow all models Weight: 1.0, }, }, nil @@ -374,7 +374,7 @@ func NewTestSetup(t *testing.T) *TestSetup { Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, @@ -651,7 +651,7 @@ func CreateTestSetupWithConversationThreshold(t *testing.T, threshold int) *Test Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, }, }, @@ -672,7 +672,7 @@ func CreateTestSetupWithExcludeSystemPrompt(t *testing.T, excludeSystem bool) *T Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, }, }, @@ -694,7 +694,7 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e Keys: []schemas.Key{ { Value: *schemas.NewEnvVar("env.OPENAI_API_KEY"), - Models: []string{}, + Models: []string{"*"}, Weight: 1.0, }, }, diff --git a/plugins/semanticcache/version b/plugins/semanticcache/version index f32f94b9f6..8e03717dca 100644 --- a/plugins/semanticcache/version +++ b/plugins/semanticcache/version @@ -1 +1 @@ -1.4.34 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/plugins/telemetry/changelog.md b/plugins/telemetry/changelog.md index e69de29bb2..9d094203da 100644 --- a/plugins/telemetry/changelog.md +++ b/plugins/telemetry/changelog.md @@ -0,0 +1 @@ +- chore: upgraded core to v1.5.1 and framework to v1.3.1 diff --git a/plugins/telemetry/go.mod b/plugins/telemetry/go.mod index 707bfbe1fd..4609a1d8fb 100644 --- a/plugins/telemetry/go.mod +++ b/plugins/telemetry/go.mod @@ -3,8 +3,8 @@ module github.com/maximhq/bifrost/plugins/telemetry go 1.26.1 require ( - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 + github.com/maximhq/bifrost/core v1.5.1 + github.com/maximhq/bifrost/framework v1.3.1 github.com/prometheus/client_golang v1.23.2 github.com/valyala/fasthttp v1.68.0 ) diff --git a/plugins/telemetry/go.sum b/plugins/telemetry/go.sum index 76e4d1c139..b0e86f039f 100644 --- a/plugins/telemetry/go.sum +++ b/plugins/telemetry/go.sum @@ -195,10 +195,10 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= +github.com/maximhq/bifrost/core v1.5.1 h1:iJoVnI4q0CpNylBqXLVaZUc0qgJhd8j8Xa2vtNfZ8bc= +github.com/maximhq/bifrost/core v1.5.1/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= +github.com/maximhq/bifrost/framework v1.3.1 h1:HpKD0JigkxsR6+jI3DDxAm9AKsO241E3sj2BpxG82Xs= +github.com/maximhq/bifrost/framework v1.3.1/go.mod h1:M+MDjP4cRZMinI2qk0DHtTp9ayFWaoQ2Ye+ikmyhGYQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/oapi-codegen/runtime v1.1.1 h1:EXLHh0DXIJnWhdRPN2w4MXAzFyE4CskzhNLUmtpMYro= diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index 3c1ece58bf..7025b902e7 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -136,6 +136,7 @@ func Init(config *Config, pricingManager *modelcatalog.ModelCatalog, logger sche defaultBifrostLabels := []string{ "provider", "model", + "alias", "method", "virtual_key_id", "virtual_key_name", @@ -359,7 +360,17 @@ func (p *PrometheusPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas. // - Request latency // - Total request count func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - requestType, provider, model := bifrost.GetResponseFields(result, bifrostErr) + requestType, provider, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) + + // Determine effective model label and alias label (mirrors applyModelAlias logic in logging) + model := originalModel + alias := "" + if resolvedModel != "" { + model = resolvedModel + if resolvedModel != originalModel { + alias = originalModel + } + } startTime, ok := ctx.Value(startTimeKey).(time.Time) if !ok { @@ -393,6 +404,7 @@ func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche labelValues := map[string]string{ "provider": string(provider), "model": model, + "alias": alias, "method": string(requestType), "virtual_key_id": virtualKeyID, "virtual_key_name": virtualKeyName, @@ -425,6 +437,8 @@ func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche streamEndIndicatorValue := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator) isFinalChunk, hasFinalChunkIndicator := streamEndIndicatorValue.(bool) + pricingScopes := modelcatalog.PricingLookupScopesFromContext(ctx, string(provider)) + // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread go func() { // For streaming requests, handle per-token metrics for intermediate chunks @@ -447,7 +461,7 @@ func (p *PrometheusPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *sche cost := 0.0 if p.pricingManager != nil && result != nil { - cost = p.pricingManager.CalculateCost(result) + cost = p.pricingManager.CalculateCost(result, pricingScopes) } p.UpstreamRequestsTotal.WithLabelValues(promLabelValues...).Inc() diff --git a/plugins/telemetry/version b/plugins/telemetry/version index 62f0c2cadb..8e03717dca 100644 --- a/plugins/telemetry/version +++ b/plugins/telemetry/version @@ -1 +1 @@ -1.4.36 \ No newline at end of file +1.5.1 \ No newline at end of file diff --git a/tests/e2e/api/collections/bifrost-v1-vk-auth.postman_collection.json b/tests/e2e/api/collections/bifrost-v1-vk-auth.postman_collection.json index c4f494c37e..0404a316d4 100644 --- a/tests/e2e/api/collections/bifrost-v1-vk-auth.postman_collection.json +++ b/tests/e2e/api/collections/bifrost-v1-vk-auth.postman_collection.json @@ -60,7 +60,7 @@ "exec": [ "var timestamp = Date.now();", "var uniqueName = 'VK Auth Test ' + timestamp;", - "pm.request.body.raw = JSON.stringify({name: uniqueName});" + "pm.request.body.raw = JSON.stringify({name: uniqueName, provider_configs: [{provider: 'openai', weight: 1.0, allowed_models: ['*'], key_ids: ['*']}]});" ] } }, diff --git a/tests/e2e/api/collections/bifrost-v1-vk-routing.postman_collection.json b/tests/e2e/api/collections/bifrost-v1-vk-routing.postman_collection.json index 669349c46a..f003e19d6a 100644 --- a/tests/e2e/api/collections/bifrost-v1-vk-routing.postman_collection.json +++ b/tests/e2e/api/collections/bifrost-v1-vk-routing.postman_collection.json @@ -26,7 +26,8 @@ " provider_configs: [{", " provider: pm.collectionVariables.get('provider') || 'openai',", " weight: 1.0,", - " allowed_models: [pm.collectionVariables.get('chat_model') || 'gpt-4o']", + " allowed_models: ['*'],", + " key_ids: ['*']", " }]", "};", "pm.request.body.raw = JSON.stringify(body);" diff --git a/tests/e2e/core/fixtures/test-data.fixture.ts b/tests/e2e/core/fixtures/test-data.fixture.ts index 3b3d6018bf..43fe37aef1 100644 --- a/tests/e2e/core/fixtures/test-data.fixture.ts +++ b/tests/e2e/core/fixtures/test-data.fixture.ts @@ -80,7 +80,7 @@ export class TestDataFactory { return { name: this.uniqueId('key'), value: `sk-test-${this.uniqueId()}`, - models: [], + models: ['*'], weight: 1.0, ...overrides, } @@ -155,8 +155,8 @@ export class TestDataFactory { return { provider: 'openai', weight: 1.0, - allowedModels: [], - keyIds: [], + allowedModels: ['*'], + keyIds: ['*'], ...overrides, } } diff --git a/tests/e2e/features/providers/providers.data.ts b/tests/e2e/features/providers/providers.data.ts index 763a29e094..3e80f31f45 100644 --- a/tests/e2e/features/providers/providers.data.ts +++ b/tests/e2e/features/providers/providers.data.ts @@ -8,7 +8,7 @@ export function createProviderKeyData(overrides: Partial = {} return { name: `Test Key ${timestamp}`, value: `sk-test-${timestamp}-${Math.random().toString(36).substring(7)}`, - models: [], + models: ['*'], weight: 1.0, ...overrides, } diff --git a/tests/e2e/features/virtual-keys/pages/virtual-keys.page.ts b/tests/e2e/features/virtual-keys/pages/virtual-keys.page.ts index 69726a7b39..58a4d877b2 100644 --- a/tests/e2e/features/virtual-keys/pages/virtual-keys.page.ts +++ b/tests/e2e/features/virtual-keys/pages/virtual-keys.page.ts @@ -60,6 +60,7 @@ export interface ProviderConfig { provider: string weight?: number allowedModels?: string[] + keyIds?: string[] budget?: BudgetConfig rateLimit?: RateLimitConfig } diff --git a/tests/e2e/features/virtual-keys/virtual-keys.data.ts b/tests/e2e/features/virtual-keys/virtual-keys.data.ts index 23d2e49c39..5d3a8fb007 100644 --- a/tests/e2e/features/virtual-keys/virtual-keys.data.ts +++ b/tests/e2e/features/virtual-keys/virtual-keys.data.ts @@ -30,6 +30,8 @@ export function createVirtualKeyWithProvider( { provider, weight: 1.0, + allowedModels: ['*'], + keyIds: ['*'], }, ], ...vkOverrides, @@ -87,6 +89,8 @@ export function createVirtualKeyWithMultipleProviders( providerConfigs: providers.map((provider) => ({ provider, weight, + allowedModels: ['*'], + keyIds: ['*'], })), ...vkOverrides, } @@ -99,6 +103,8 @@ export function createProviderConfig(overrides: Partial = {}): P return { provider: 'openai', weight: 1.0, + allowedModels: ['*'], + keyIds: ['*'], ...overrides, } } diff --git a/tests/governance/advancedscenarios_test.go b/tests/governance/advancedscenarios_test.go index f613cbfe0d..aa928c7dc9 100644 --- a/tests/governance/advancedscenarios_test.go +++ b/tests/governance/advancedscenarios_test.go @@ -976,8 +976,10 @@ func TestProviderConfigBudgetUpdateAfterExhaustion(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: initialBudget, ResetDuration: "1h", @@ -1063,9 +1065,11 @@ func TestProviderConfigBudgetUpdateAfterExhaustion(t *testing.T) { Body: UpdateVirtualKeyRequest{ ProviderConfigs: []ProviderConfigRequest{ { - ID: &providerConfigID, - Provider: "openai", - Weight: 1.0, + ID: &providerConfigID, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: newBudget, ResetDuration: "1h", @@ -1132,8 +1136,10 @@ func TestVKDeletionCascadeComplete(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: 5.0, ResetDuration: "1h", diff --git a/tests/governance/config.json b/tests/governance/config.json index fc64c8e74c..bd9080a064 100644 --- a/tests/governance/config.json +++ b/tests/governance/config.json @@ -7,6 +7,7 @@ "name": "OpenAI Test Key", "value": "env.OPENAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -20,6 +21,7 @@ "name": "Anthropic Test Key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -32,7 +34,8 @@ { "name": "OpenRouter Test Key", "value": "env.OPENROUTER_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { diff --git a/tests/governance/configupdatesync_test.go b/tests/governance/configupdatesync_test.go index 6f2abd87c4..994d0448f6 100644 --- a/tests/governance/configupdatesync_test.go +++ b/tests/governance/configupdatesync_test.go @@ -421,8 +421,10 @@ func TestProviderRateLimitUpdateSyncToMemory(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &initialTokenLimit, TokenResetDuration: &tokenResetDuration, @@ -512,9 +514,11 @@ func TestProviderRateLimitUpdateSyncToMemory(t *testing.T) { Body: UpdateVirtualKeyRequest{ ProviderConfigs: []ProviderConfigRequest{ { - ID: &providerConfigID, - Provider: "openai", - Weight: 1.0, + ID: &providerConfigID, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &newLowerLimit, TokenResetDuration: &tokenResetDuration, @@ -955,8 +959,10 @@ func TestProviderBudgetUpdateSyncToMemory(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: initialBudget, ResetDuration: resetDuration, @@ -1054,9 +1060,11 @@ func TestProviderBudgetUpdateSyncToMemory(t *testing.T) { Body: UpdateVirtualKeyRequest{ ProviderConfigs: []ProviderConfigRequest{ { - ID: &providerConfigID, - Provider: "openai", - Weight: 1.0, + ID: &providerConfigID, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: newLowerBudget, ResetDuration: resetDuration, diff --git a/tests/governance/e2e_test.go b/tests/governance/e2e_test.go index ae56ffe0dd..bab26fff30 100644 --- a/tests/governance/e2e_test.go +++ b/tests/governance/e2e_test.go @@ -252,8 +252,10 @@ func TestFullBudgetHierarchyEnforcement(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: providerBudget, ResetDuration: "1h", @@ -1348,13 +1350,15 @@ func TestWeightedProviderLoadBalancing(t *testing.T) { ProviderConfigs: []ProviderConfigRequest{ { Provider: "openai", - Weight: openaiWeight, + Weight: &openaiWeight, AllowedModels: []string{"gpt-4o"}, + KeyIDs: []string{"*"}, }, { Provider: "azure", - Weight: azureWeight, + Weight: &azureWeight, AllowedModels: []string{"gpt-4o"}, + KeyIDs: []string{"*"}, }, }, }, @@ -1422,7 +1426,7 @@ func TestWeightedProviderLoadBalancing(t *testing.T) { // Try to detect which provider was used // Check if model in response contains provider name if provider, ok := resp.Body["extra_fields"].(map[string]interface{})["provider"].(string); ok { - model, ok := resp.Body["extra_fields"].(map[string]interface{})["model_requested"].(string) + model, ok := resp.Body["extra_fields"].(map[string]interface{})["original_model_requested"].(string) if !ok { t.Logf("Request %d failed to get model requested", i+1) continue @@ -1482,13 +1486,15 @@ func TestProviderFallbackMechanism(t *testing.T) { ProviderConfigs: []ProviderConfigRequest{ { Provider: "anthropic", - Weight: anthropicWeight, + Weight: &anthropicWeight, AllowedModels: []string{"claude-3-sonnet"}, // Does NOT include gpt-4o + KeyIDs: []string{"*"}, }, { Provider: "openai", - Weight: openaiWeight, + Weight: &openaiWeight, AllowedModels: []string{"gpt-4o"}, // DOES include gpt-4o + KeyIDs: []string{"*"}, }, }, }, diff --git a/tests/governance/edgecases_test.go b/tests/governance/edgecases_test.go index 13fbed4a6e..32dad1bbe6 100644 --- a/tests/governance/edgecases_test.go +++ b/tests/governance/edgecases_test.go @@ -71,8 +71,10 @@ func TestCrissCrossComplexBudgetHierarchy(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: 0.08, // Even tighter provider budget ResetDuration: "1h", diff --git a/tests/governance/providerbudget_test.go b/tests/governance/providerbudget_test.go index c39dcbb7fd..7b34437eb7 100644 --- a/tests/governance/providerbudget_test.go +++ b/tests/governance/providerbudget_test.go @@ -23,16 +23,20 @@ func TestProviderBudgetExceeded(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: 0.01, // Specific OpenAI budget ResetDuration: "1h", }, }, { - Provider: "anthropic", - Weight: 1.0, + Provider: "anthropic", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, Budget: &BudgetRequest{ MaxLimit: 0.01, // Specific Anthropic budget ResetDuration: "1h", diff --git a/tests/governance/ratelimit_test.go b/tests/governance/ratelimit_test.go index aff8eac941..a7a45ffa92 100644 --- a/tests/governance/ratelimit_test.go +++ b/tests/governance/ratelimit_test.go @@ -169,8 +169,10 @@ func TestProviderConfigTokenRateLimit(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &providerTokenLimit, TokenResetDuration: &tokenResetDuration, @@ -248,8 +250,10 @@ func TestProviderConfigRequestRateLimit(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ RequestMaxLimit: &providerRequestLimit, RequestResetDuration: &requestResetDuration, @@ -328,16 +332,20 @@ func TestMultipleProvidersSeparateRateLimits(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &openaiLimit, TokenResetDuration: &tokenResetDuration, }, }, { - Provider: "anthropic", - Weight: 1.0, + Provider: "anthropic", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &anthropicLimit, TokenResetDuration: &tokenResetDuration, @@ -400,8 +408,10 @@ func TestProviderAndVKRateLimitTogether(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &providerTokenLimit, TokenResetDuration: &providerTokenResetDuration, @@ -840,16 +850,20 @@ func TestProviderLevelRateLimitUsageTracking(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &openaiTokenLimit, TokenResetDuration: &tokenResetDuration, }, }, { - Provider: "anthropic", - Weight: 1.0, + Provider: "anthropic", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &anthropicTokenLimit, TokenResetDuration: &tokenResetDuration, diff --git a/tests/governance/ratelimitenforcement_test.go b/tests/governance/ratelimitenforcement_test.go index ab8ab6ca92..006a836f1b 100644 --- a/tests/governance/ratelimitenforcement_test.go +++ b/tests/governance/ratelimitenforcement_test.go @@ -230,8 +230,10 @@ func TestProviderConfigTokenRateLimitEnforcement(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ TokenMaxLimit: &providerTokenLimit, TokenResetDuration: &tokenResetDuration, @@ -357,8 +359,10 @@ func TestProviderConfigRequestRateLimitEnforcement(t *testing.T) { Name: vkName, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ RequestMaxLimit: &providerRequestLimit, RequestResetDuration: &requestResetDuration, @@ -452,8 +456,10 @@ func TestProviderAndVKRateLimitBothEnforced(t *testing.T) { }, ProviderConfigs: []ProviderConfigRequest{ { - Provider: "openai", - Weight: 1.0, + Provider: "openai", + Weight: float64Ptr(1.0), + AllowedModels: []string{"*"}, + KeyIDs: []string{"*"}, RateLimit: &CreateRateLimitRequest{ RequestMaxLimit: &providerRequestLimit, RequestResetDuration: &requestResetDuration, diff --git a/tests/governance/test_utils.go b/tests/governance/test_utils.go index 2e415250ce..487598ef46 100644 --- a/tests/governance/test_utils.go +++ b/tests/governance/test_utils.go @@ -221,12 +221,18 @@ type CreateVirtualKeyRequest struct { type ProviderConfigRequest struct { ID *uint `json:"id,omitempty"` Provider string `json:"provider"` - Weight float64 `json:"weight,omitempty"` + Weight *float64 `json:"weight,omitempty"` AllowedModels []string `json:"allowed_models,omitempty"` + KeyIDs []string `json:"key_ids,omitempty"` Budget *BudgetRequest `json:"budget,omitempty"` RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` } +// float64Ptr returns a pointer to a float64 value +func float64Ptr(v float64) *float64 { + return &v +} + // BudgetRequest represents a budget request type BudgetRequest struct { MaxLimit float64 `json:"max_limit"` diff --git a/tests/integrations/python/config.json b/tests/integrations/python/config.json index 73f32f8726..00b89b5bdb 100644 --- a/tests/integrations/python/config.json +++ b/tests/integrations/python/config.json @@ -23,6 +23,7 @@ "name": "OpenAI API Key", "value": "env.OPENAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -36,6 +37,7 @@ "name": "ElevenLabs API Key", "value": "env.ELEVENLABS_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": false } ], @@ -49,6 +51,7 @@ "name": "Xai API Key", "value": "env.XAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": false } ], @@ -62,6 +65,7 @@ "name": "Hugging Face API Key", "value": "env.HUGGING_FACE_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": false } ], @@ -75,6 +79,7 @@ "name": "Anthropic API Key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -88,6 +93,7 @@ "name": "Gemini API Key", "value": "env.GEMINI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -104,7 +110,8 @@ "region": "env.GOOGLE_LOCATION", "auth_credentials": "env.VERTEX_CREDENTIALS" }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -116,7 +123,8 @@ { "name": "Mistral API Key", "value": "env.MISTRAL_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -128,7 +136,8 @@ { "name": "Cohere API Key", "value": "env.COHERE_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -140,7 +149,8 @@ { "name": "Parasail API Key", "value": "env.PARASAIL_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -152,7 +162,8 @@ { "name": "Groq API Key", "value": "env.GROQ_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -164,7 +175,8 @@ { "name": "Perplexity API Key", "value": "env.PERPLEXITY_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -176,7 +188,8 @@ { "name": "Cerebras API Key", "value": "env.CEREBRAS_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -188,7 +201,8 @@ { "name": "OpenRouter API Key", "value": "env.OPENROUTER_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -214,7 +228,8 @@ "gpt-image-1": "gpt-image-1" } }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -232,6 +247,7 @@ "arn": "env.AWS_ARN" }, "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -296,7 +312,25 @@ "name": "Test Key", "id": "vk-test", "value": "sk-bf-test-key", - "is_active": true + "is_active": true, + "provider_configs": [ + { "provider": "openai", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "elevenlabs", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "xai", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "huggingface", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "anthropic", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "gemini", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "vertex", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "mistral", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "cohere", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "parasail", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "groq", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "perplexity", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "cerebras", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "openrouter", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "azure", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "bedrock", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 } + ] } ] }, @@ -312,4 +346,4 @@ "max_request_body_size_mb": 100, "enable_litellm_fallbacks": false } -} \ No newline at end of file +} diff --git a/tests/integrations/python/config.yml b/tests/integrations/python/config.yml index 6b6b489ac2..f1b01580b0 100644 --- a/tests/integrations/python/config.yml +++ b/tests/integrations/python/config.yml @@ -174,6 +174,8 @@ providers: thinking: "us.anthropic.claude-opus-4-5-20251101-v1:0" text_completion: "mistral.mistral-7b-instruct-v0:2" embeddings: "global.cohere.embed-v4:0" + image_generation: "amazon.titan-image-generator-v2:0" + image_variation: "amazon.titan-image-generator-v2:0" batch_inline: "anthropic.claude-3-5-sonnet-20240620-v1:0" image_edit: "amazon.nova-canvas-v1:0" batch_list: "anthropic.claude-3-5-sonnet-20240620-v1:0" @@ -517,9 +519,11 @@ provider_scenarios: file_retrieve: true # Bedrock retrieves S3 object metadata file_delete: true # Bedrock deletes S3 objects file_content: true # Bedrock downloads S3 object content - image_edit: true # Bedrock supports image editing via Nova Canvas + image_generation: true # Bedrock supports image generation via invoke (Titan, SA, cross-provider) + image_edit: true # Bedrock supports image editing via invoke (Titan, SA) + image_variation: true # Bedrock supports image variation via invoke (Titan IMAGE_VARIATION) count_tokens: true # Bedrock supports token counting via CountTokens API - + cohere: simple_chat: true multi_turn_conversation: true diff --git a/tests/integrations/python/pyproject.toml b/tests/integrations/python/pyproject.toml index 70e3167a4b..8d49b81278 100644 --- a/tests/integrations/python/pyproject.toml +++ b/tests/integrations/python/pyproject.toml @@ -24,12 +24,12 @@ dependencies = [ # AI/ML SDK dependencies "openai>=1.30.0", "anthropic>=0.25.0", - "litellm>=1.80.5", - "langchain-openai>=0.1.0", - "langchain-core>=0.3.0", - "langchain-anthropic>=0.1.0", + "litellm==1.80.5", + "langchain-openai==0.1.0", + "langchain-core==0.3.81", + "langchain-anthropic==0.1.0", "langchain-google-genai==4.1.1", - "langchain-mistralai>=0.1.0", + "langchain-mistralai==0.1.0", "langgraph>=0.1.0", "mistralai>=0.4.0", "google-genai>=1.50.0", @@ -123,4 +123,4 @@ exclude_lines = [ [tool.uv] -exclude-newer = "7 days" \ No newline at end of file +exclude-newer = "2026-04-08" \ No newline at end of file diff --git a/tests/integrations/python/tests/test_bedrock.py b/tests/integrations/python/tests/test_bedrock.py index b5a2a4503e..8640390e33 100644 --- a/tests/integrations/python/tests/test_bedrock.py +++ b/tests/integrations/python/tests/test_bedrock.py @@ -45,6 +45,29 @@ 26. Count tokens with tool definitions - Cross-provider 27. Count tokens from long text - Cross-provider 28. Count tokens from multi-turn conversation - Cross-provider + +Invoke Endpoint β€” Image Generation Tests (TestBedrockInvokeEndpoint): +29. Titan image generation via invoke (taskType=TEXT_IMAGE) +30. Titan embeddings via invoke (inputText) +31. Titan embeddings with params via invoke (inputText + params) +32. Cohere embeddings via invoke (texts array) +33. Titan inpainting via invoke (taskType=INPAINTING) +34. Titan outpainting via invoke (taskType=OUTPAINTING) +35. Titan background removal via invoke (taskType=BACKGROUND_REMOVAL) +36. Titan image variation via invoke (taskType=IMAGE_VARIATION) +37. Stability AI image inpaint via invoke (image+mask) +38. Vertex Imagen image generation via invoke (cross-provider) +39. OpenAI gpt-image-1 via invoke (cross-provider) +40. Titan text generation via invoke (inputText+textGenerationConfig, not misrouted as embedding) +41. Cohere embeddings via invoke with inputs payload (mixed text+image, not misrouted as text completion) +42. Cohere embeddings via invoke with explicit embedding_types=["float"] +43. Cohere embeddings via invoke with embedding_types=["int8"] (regression: was silently dropped) +44. Cohere embeddings via invoke with embedding_types=["uint8"] (regression: was silently dropped) +45. Cohere embeddings via invoke with embedding_types=["float","int8"] (multi-type, none dropped) +46. Anthropic claude via invoke with messages array (ResponsesRequest path β†’ Anthropic Messages format) +47. Nova via invoke with messages array (ResponsesRequest path β†’ Converse/Nova format) +48. AI21 Jamba via invoke with messages array (ResponsesRequest path β†’ AI21 Choices format) +49. Anthropic claude via invoke-with-response-stream with messages (ResponsesRequest streaming path) """ import base64 @@ -54,10 +77,12 @@ from typing import Any, Dict, List import boto3 +import botocore.exceptions import pytest from .utils.common import ( - BASE64_IMAGE, + BASE64_IMAGE_LARGE, + BASE64_TITAN_MASK_IMAGE, CALCULATOR_TOOL, LOCATION_KEYWORDS, MULTI_TURN_MESSAGES, @@ -335,9 +360,9 @@ def extract_system_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, An class TestBedrockIntegration: """Test suite for Bedrock integration covering core scenarios""" + @pytest.mark.skip(reason="Skipping text completion invoke test") @skip_if_no_api_key("bedrock") def test_01_text_completion_invoke(self, bedrock_client, test_config): - pytest.skip("Skipping text completion invoke test") model_id = get_model("bedrock", "text_completion") request_body = { @@ -474,7 +499,7 @@ def test_03_image_analysis(self, bedrock_client, test_config, provider, model): }, { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE}"}, + "image_url": {"url": f"data:image/png;base64,{BASE64_IMAGE_LARGE}"}, }, ], } @@ -2038,3 +2063,833 @@ def test_28_count_tokens_multi_turn_conversation(self, bedrock_client, provider, ), f"Multi-turn conversation should have >15 tokens, got {response['inputTokens']}" print(f"βœ“ Multi-turn conversation token count: {response['inputTokens']} tokens") + + +# --------------------------------------------------------------------------- +# Invoke Endpoint β€” Image Generation, Image Edit, Image Variation, Embeddings +# --------------------------------------------------------------------------- +# These tests exercise the /bedrock/model/{modelId}/invoke route using +# native Bedrock payload formats (taskType-based for Titan/Nova Canvas, +# flat-field for Stability AI) as well as cross-provider model IDs +# (vertex/..., openai/...) routed through the same invoke endpoint. +# --------------------------------------------------------------------------- + +def _assert_invoke_images(response_body: dict, min_images: int = 1) -> None: + """Assert that an invoke response contains at least min_images base64 images.""" + images = response_body.get("images") or [] + assert isinstance(images, list), ( + f"Expected 'images' to be a list, got {type(images).__name__}. " + f"Response keys: {list(response_body.keys())}" + ) + assert len(images) >= min_images, ( + f"Expected at least {min_images} image(s) in response, got {len(images)}. " + f"Response keys: {list(response_body.keys())}" + ) + for i, img in enumerate(images): + assert isinstance(img, str) and len(img) > 0, f"Image {i} is not a non-empty string" + print(f" βœ“ {len(images)} image(s) returned") + + +def _assert_invoke_embedding(response_body: dict) -> None: + """Assert that an invoke response contains a non-empty embedding vector.""" + embedding = response_body.get("embedding") or [] + assert len(embedding) > 0, ( + f"Expected 'embedding' array in response, got keys: {list(response_body.keys())}" + ) + assert all(isinstance(v, (int, float)) for v in embedding), "Embedding must be numeric" + print(f" βœ“ embedding dim={len(embedding)}") + + +class TestBedrockInvokeEndpoint: + """ + Tests for the Bedrock /invoke and /invoke-with-response-stream endpoints. + + Covers native Bedrock payload formats for: + - Image generation (Titan TEXT_IMAGE, Stability AI, Vertex Imagen, OpenAI) + - Image editing (Titan INPAINTING / OUTPAINTING / BACKGROUND_REMOVAL, SA inpaint) + - Image variation (Titan IMAGE_VARIATION) + - Embeddings (Titan embed text v2, Cohere embed English v3) + - Messages path (Anthropic, Nova, AI21 Jamba via messages array β†’ ResponsesRequest) + - Messages streaming (invoke-with-response-stream with messages array) + """ + + # ------------------------------------------------------------------ # + # 29. Titan image generation # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_29_invoke_titan_image_generation(self, bedrock_client): + """Test Case 29: Titan Image Generator v2 via invoke β€” taskType=TEXT_IMAGE""" + print("\n=== Test 29: Titan image generation via invoke ===") + + body = { + "taskType": "TEXT_IMAGE", + "textToImageParams": { + "text": "a serene mountain lake at sunset", + "negativeText": "blurry, low quality", + }, + "imageGenerationConfig": { + "numberOfImages": 1, + "width": 512, + "height": 512, + }, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-image-generator-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 30. Titan embed text v2 # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_30_invoke_titan_embeddings(self, bedrock_client): + """Test Case 30: Titan Embed Text v2 via invoke β€” inputText""" + print("\n=== Test 30: Titan embeddings via invoke ===") + + body = {"inputText": "the quick brown fox jumps over the lazy dog"} + + response = bedrock_client.invoke_model( + modelId="amazon.titan-embed-text-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_embedding(out) + # inputTextTokenCount is returned by the Bedrock-format response + assert "inputTextTokenCount" in out, ( + f"Expected 'inputTextTokenCount' in Titan embed response, got: {list(out.keys())}" + ) + print(f" βœ“ inputTextTokenCount={out['inputTextTokenCount']}") + + # ------------------------------------------------------------------ # + # 31. Titan embed with dimensions + normalize # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_31_invoke_titan_embeddings_with_params(self, bedrock_client): + """Test Case 31: Titan Embed Text v2 via invoke β€” dimensions + normalize""" + print("\n=== Test 31: Titan embeddings with params via invoke ===") + + body = { + "inputText": "machine learning and artificial intelligence", + "dimensions": 256, + "normalize": True, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-embed-text-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_embedding(out) + assert len(out["embedding"]) == 256, ( + f"Expected 256-dim embedding, got {len(out['embedding'])}" + ) + + # ------------------------------------------------------------------ # + # 32. Cohere embed English v3 # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_32_invoke_cohere_embeddings(self, bedrock_client): + """Test Case 32: Cohere Embed English v3 via invoke β€” texts array""" + print("\n=== Test 32: Cohere embeddings via invoke ===") + + body = { + "texts": ["hello world", "goodbye world"], + "input_type": "search_document", + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + # Cohere native response uses "embeddings" (plural list-of-lists) + # Bifrost may return Titan-compat single "embedding" for invoke + if "embedding" in out: + _assert_invoke_embedding(out) + else: + embeddings = out.get("embeddings") + assert isinstance(embeddings, list) and len(embeddings) == len(body["texts"]), ( + f"Expected {len(body['texts'])} embeddings, got: {out}" + ) + for i, vector in enumerate(embeddings): + assert isinstance(vector, list) and len(vector) > 0, f"Embedding {i} is empty" + assert all(isinstance(v, (int, float)) for v in vector), ( + f"Embedding {i} must be numeric" + ) + print(f" βœ“ Cohere embedding response keys: {list(out.keys())}") + + # ------------------------------------------------------------------ # + # 33. Titan INPAINTING # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_33_invoke_titan_inpainting(self, bedrock_client): + """Test Case 33: Titan Image Generator v2 via invoke β€” INPAINTING""" + print("\n=== Test 33: Titan INPAINTING via invoke ===") + + body = { + "taskType": "INPAINTING", + "inPaintingParams": { + "image": BASE64_IMAGE_LARGE, + "maskImage": BASE64_TITAN_MASK_IMAGE, + "text": "a beautiful garden with flowers", + "negativeText": "blurry", + }, + "imageGenerationConfig": {"numberOfImages": 1}, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-image-generator-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 34. Titan OUTPAINTING # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_34_invoke_titan_outpainting(self, bedrock_client): + """Test Case 34: Titan Image Generator v2 via invoke β€” OUTPAINTING""" + print("\n=== Test 34: Titan OUTPAINTING via invoke ===") + + body = { + "taskType": "OUTPAINTING", + "outPaintingParams": { + "image": BASE64_IMAGE_LARGE, + "maskImage": BASE64_TITAN_MASK_IMAGE, + "text": "extend the scene with a meadow", + "outPaintingMode": "DEFAULT", + }, + "imageGenerationConfig": {"numberOfImages": 1}, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-image-generator-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 35. Titan BACKGROUND_REMOVAL # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_35_invoke_titan_background_removal(self, bedrock_client): + """Test Case 35: Titan Image Generator v2 via invoke β€” BACKGROUND_REMOVAL""" + print("\n=== Test 35: Titan BACKGROUND_REMOVAL via invoke ===") + + body = { + "taskType": "BACKGROUND_REMOVAL", + "backgroundRemovalParams": {"image": BASE64_IMAGE_LARGE}, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-image-generator-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 36. Titan IMAGE_VARIATION # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_36_invoke_titan_image_variation(self, bedrock_client): + """Test Case 36: Titan Image Generator v2 via invoke β€” IMAGE_VARIATION""" + print("\n=== Test 36: Titan IMAGE_VARIATION via invoke ===") + + body = { + "taskType": "IMAGE_VARIATION", + "imageVariationParams": { + "images": [BASE64_IMAGE_LARGE], + "text": "same style with a different color palette", + "similarityStrength": 0.7, + }, + "imageGenerationConfig": {"numberOfImages": 1}, + } + + response = bedrock_client.invoke_model( + modelId="amazon.titan-image-generator-v2:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 37. Stability AI β€” image inpaint # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_37_invoke_stability_ai_inpaint(self, bedrock_client): + """Test Case 37: Stability AI stable-image-inpaint via invoke β€” image+mask+prompt""" + print("\n=== Test 37: Stability AI inpaint via invoke ===") + + body = { + "image": BASE64_IMAGE_LARGE, + "mask": BASE64_IMAGE_LARGE, + "prompt": "replace masked area with flowers", + "output_format": "png", + } + + response = bedrock_client.invoke_model( + modelId="us.stability.stable-image-inpaint-v1:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 38. Vertex Imagen β€” cross-provider via invoke # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("vertex") + def test_38_invoke_vertex_imagen(self, bedrock_client): + """Test Case 38: Vertex Imagen 4 via Bedrock invoke endpoint (cross-provider)""" + print("\n=== Test 38: Vertex Imagen via invoke ===") + + body = {"prompt": "a gecko resting on a tropical leaf"} + + response = bedrock_client.invoke_model( + modelId="vertex/imagen-4.0-generate-001", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 39. OpenAI gpt-image-1 β€” cross-provider via invoke # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("openai") + def test_39_invoke_openai_image_generation(self, bedrock_client): + """Test Case 39: OpenAI gpt-image-1 via Bedrock invoke endpoint (cross-provider)""" + print("\n=== Test 39: OpenAI gpt-image-1 via invoke ===") + + body = { + "prompt": "a gecko resting on a tropical leaf", + "n": 1, + "quality": "low", + } + + response = bedrock_client.invoke_model( + modelId="openai/gpt-image-1", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + _assert_invoke_images(out) + + # ------------------------------------------------------------------ # + # 40. Titan text generation β€” inputText must NOT route as embedding # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_40_invoke_titan_text_generation(self, bedrock_client): + """Test Case 40: Titan Text via invoke β€” inputText must not be misrouted as embedding. + + Regression test for the bug where DetectInvokeRequestType returned EmbeddingRequest for any + body with 'inputText', regardless of model. Detection is now model-ID-based: only models + whose ID contains 'embed' are routed as embeddings. The response must contain 'results'. + """ + print("\n=== Test 40: Titan text generation via invoke (not embedding) ===") + + # Intentionally omit textGenerationConfig to cover the bare-inputText case β€” + # the fix must use model ID (not body shape) to distinguish text-gen from embedding. + body = { + "inputText": "What is the capital of France? Answer in one word.", + } + + try: + response = bedrock_client.invoke_model( + modelId="amazon.titan-text-express-v1", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + except botocore.exceptions.ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("ResourceNotFoundException", "ValidationException"): + pytest.skip(f"Titan text model no longer available: {e}") + raise + out = json.loads(response["body"].read()) + + assert "embedding" not in out, ( + f"Request was misrouted to the embedding path β€” response contains 'embedding' key. " + f"Response keys: {list(out.keys())}" + ) + assert "results" in out, ( + f"Expected 'results' in Titan text generation response, got: {list(out.keys())}" + ) + results = out["results"] + assert len(results) > 0 and results[0].get("outputText"), ( + f"Expected non-empty outputText in results, got: {results}" + ) + print(f" βœ“ outputText={results[0]['outputText'][:60]!r}") + + # ------------------------------------------------------------------ # + # 41. Cohere embed β€” inputs payload must NOT route as text completion # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_41_invoke_cohere_embeddings_inputs(self, bedrock_client): + """Test Case 41: Cohere Embed via invoke β€” inputs payload must not be misrouted as text completion. + + Regression test for the bug where DetectInvokeRequestType only checked for the 'texts' field + when detecting Cohere embeddings. Requests using the 'inputs' field (mixed text+image payloads) + fell through to TextCompletionRequest. Detection must be model-ID-based (contains 'embed') + and cover all Cohere embedding payload shapes: 'texts', 'inputs', and 'images'. + """ + print("\n=== Test 41: Cohere embeddings via invoke (inputs payload, not text completion) ===") + + # Use 'inputs' field instead of 'texts' β€” this is the payload shape that was misrouted + body = { + "inputs": [{"text": "hello world"}, {"text": "goodbye world"}], + "input_type": "search_document", + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + has_embedding = "embedding" in out or "embeddings" in out + assert has_embedding, ( + f"Request was misrouted β€” expected 'embedding' or 'embeddings' key but got: {list(out.keys())}. " + f"If 'results' is present, the request was routed to text completion instead of embeddings." + ) + print(f" βœ“ Cohere inputs embedding response keys: {list(out.keys())}") + + # ------------------------------------------------------------------ # + # 42. Cohere embed β€” embedding_types float (explicit) # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_42_invoke_cohere_embedding_type_float(self, bedrock_client): + """Test Case 42: Cohere Embed via invoke β€” explicit embedding_types=["float"]. + + Verifies that requesting a single float encoding returns the expected + embeddings_by_type response structure with float vectors. + """ + print("\n=== Test 42: Cohere embedding_types float ===") + + body = { + "texts": ["the quick brown fox"], + "input_type": "search_document", + "embedding_types": ["float"], + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + assert out.get("response_type") == "embeddings_by_type", ( + f"Expected response_type='embeddings_by_type', got: {out.get('response_type')}" + ) + embeddings = out.get("embeddings", {}) + assert "float" in embeddings, f"Expected 'float' key in embeddings, got: {list(embeddings.keys())}" + float_vecs = embeddings["float"] + assert isinstance(float_vecs, list) and len(float_vecs) == 1, ( + f"Expected 1 float vector, got: {float_vecs}" + ) + assert isinstance(float_vecs[0], list) and len(float_vecs[0]) > 0, "Float vector is empty" + assert all(isinstance(v, float) for v in float_vecs[0]), "Float vector must contain floats" + print(f" βœ“ float embedding dim={len(float_vecs[0])}") + + # ------------------------------------------------------------------ # + # 43. Cohere embed β€” embedding_types int8 # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_43_invoke_cohere_embedding_type_int8(self, bedrock_client): + """Test Case 43: Cohere Embed via invoke β€” embedding_types=["int8"]. + + Regression test for the bug where int8 (and other non-float encoding types) + were silently dropped because the embeddings_by_type parser only declared + 'float' and 'base64' fields in its anonymous struct. + """ + print("\n=== Test 43: Cohere embedding_types int8 ===") + + body = { + "texts": ["the quick brown fox"], + "input_type": "search_document", + "embedding_types": ["int8"], + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + assert out.get("response_type") == "embeddings_by_type", ( + f"Expected response_type='embeddings_by_type', got: {out.get('response_type')}" + ) + embeddings = out.get("embeddings", {}) + assert "int8" in embeddings, ( + f"Expected 'int8' key in embeddings β€” was it silently dropped? Got: {list(embeddings.keys())}" + ) + int8_vecs = embeddings["int8"] + assert isinstance(int8_vecs, list) and len(int8_vecs) == 1, ( + f"Expected 1 int8 vector, got: {int8_vecs}" + ) + assert isinstance(int8_vecs[0], list) and len(int8_vecs[0]) > 0, "int8 vector is empty" + assert all(isinstance(v, int) and -128 <= v <= 127 for v in int8_vecs[0]), ( + "int8 vector values must be integers in [-128, 127]" + ) + print(f" βœ“ int8 embedding dim={len(int8_vecs[0])}") + + # ------------------------------------------------------------------ # + # 44. Cohere embed β€” embedding_types uint8 # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_44_invoke_cohere_embedding_type_uint8(self, bedrock_client): + """Test Case 44: Cohere Embed via invoke β€” embedding_types=["uint8"]. + + Verifies that uint8 encoding is not dropped (previously silently lost + because the parser mapped []uint8 as base64 via json.Marshal). + """ + print("\n=== Test 44: Cohere embedding_types uint8 ===") + + body = { + "texts": ["the quick brown fox"], + "input_type": "search_document", + "embedding_types": ["uint8"], + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + assert out.get("response_type") == "embeddings_by_type", ( + f"Expected response_type='embeddings_by_type', got: {out.get('response_type')}" + ) + embeddings = out.get("embeddings", {}) + assert "uint8" in embeddings, ( + f"Expected 'uint8' key in embeddings β€” was it silently dropped? Got: {list(embeddings.keys())}" + ) + uint8_vecs = embeddings["uint8"] + assert isinstance(uint8_vecs, list) and len(uint8_vecs) == 1, ( + f"Expected 1 uint8 vector, got: {uint8_vecs}" + ) + assert isinstance(uint8_vecs[0], list) and len(uint8_vecs[0]) > 0, "uint8 vector is empty" + assert all(isinstance(v, int) and 0 <= v <= 255 for v in uint8_vecs[0]), ( + "uint8 vector values must be integers in [0, 255]" + ) + print(f" βœ“ uint8 embedding dim={len(uint8_vecs[0])}") + + # ------------------------------------------------------------------ # + # 45. Cohere embed β€” multiple embedding_types (float + int8) # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_45_invoke_cohere_embedding_types_multi(self, bedrock_client): + """Test Case 45: Cohere Embed via invoke β€” embedding_types=["float", "int8"]. + + Verifies that multiple encoding types are all returned without any being + dropped, and that each type contains the correct number of vectors. + """ + print("\n=== Test 45: Cohere embedding_types multi (float + int8) ===") + + texts = ["the quick brown fox", "machine learning"] + body = { + "texts": texts, + "input_type": "search_document", + "embedding_types": ["float", "int8"], + } + + response = bedrock_client.invoke_model( + modelId="cohere.embed-english-v3", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + assert out.get("response_type") == "embeddings_by_type", ( + f"Expected response_type='embeddings_by_type', got: {out.get('response_type')}" + ) + embeddings = out.get("embeddings", {}) + for enc_type in ("float", "int8"): + assert enc_type in embeddings, ( + f"Expected '{enc_type}' key in embeddings β€” was it dropped? Got: {list(embeddings.keys())}" + ) + vecs = embeddings[enc_type] + assert isinstance(vecs, list) and len(vecs) == len(texts), ( + f"Expected {len(texts)} {enc_type} vectors, got {len(vecs)}" + ) + for i, vec in enumerate(vecs): + assert isinstance(vec, list) and len(vec) > 0, f"{enc_type} vector {i} is empty" + print(f" βœ“ float dim={len(embeddings['float'][0])}, int8 dim={len(embeddings['int8'][0])}") + + # ------------------------------------------------------------------ # + # 46. Anthropic claude β€” messages path via invoke # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_46_invoke_anthropic_messages(self, bedrock_client): + """Test Case 46: Anthropic Claude via invoke with messages array β†’ ResponsesRequest path. + + Verifies that a payload containing a 'messages' array is detected as ResponsesRequest + (not TextCompletionRequest) and returns the Anthropic Messages API format: + {"type": "message", "role": "assistant", "content": [...], "stop_reason": "end_turn"}. + """ + print("\n=== Test 46: Anthropic claude via invoke (messages path) ===") + + body = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Say hello in one word."}], + } + ], + "max_tokens": 50, + } + + response = bedrock_client.invoke_model( + modelId="anthropic.claude-3-haiku-20240307-v1:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + out = json.loads(response["body"].read()) + + # Must NOT be text-completion format + assert "results" not in out and "outputs" not in out, ( + f"Request was misrouted to text-completion path β€” got keys: {list(out.keys())}" + ) + # Must be Anthropic Messages API format + assert out.get("type") == "message", ( + f"Expected type='message', got: {out.get('type')}. Keys: {list(out.keys())}" + ) + assert out.get("role") == "assistant", ( + f"Expected role='assistant', got: {out.get('role')}" + ) + content = out.get("content", []) + assert isinstance(content, list) and len(content) > 0, ( + f"Expected non-empty content list, got: {content}" + ) + text_block = next((b for b in content if b.get("type") == "text"), None) + assert text_block is not None and text_block.get("text"), ( + f"Expected a text content block, got: {content}" + ) + assert out.get("stop_reason") in ("end_turn", "max_tokens"), ( + f"Unexpected stop_reason: {out.get('stop_reason')}" + ) + print(f" βœ“ stop_reason={out['stop_reason']!r}, text={text_block['text'][:60]!r}") + + # ------------------------------------------------------------------ # + # 47. Nova β€” messages path via invoke # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_47_invoke_nova_messages(self, bedrock_client): + """Test Case 47: Amazon Nova via invoke with messages array β†’ ResponsesRequest path. + + Nova invoke with a messages array routes through ResponsesRequest and returns + the Converse-compatible format: {"output": {"message": {"role": ..., "content": [...]}}, + "stopReason": "end_turn"}. + """ + print("\n=== Test 47: Nova via invoke (messages path) ===") + + body = { + "messages": [ + { + "role": "user", + "content": [{"text": "Say hello in one word."}], + } + ], + "inferenceConfig": {"maxTokens": 50}, + } + + try: + response = bedrock_client.invoke_model( + modelId="us.amazon.nova-lite-v1:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + except botocore.exceptions.ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("ValidationException", "ResourceNotFoundException"): + pytest.skip(f"Nova model not available with this configuration: {e}") + raise + out = json.loads(response["body"].read()) + + # Must NOT be text-completion format + assert "results" not in out and "outputs" not in out, ( + f"Request was misrouted to text-completion path β€” got keys: {list(out.keys())}" + ) + # Must be Converse-compatible format + assert "output" in out, f"Expected 'output' in response, got: {list(out.keys())}" + msg = out["output"].get("message", {}) + assert msg.get("role") == "assistant", ( + f"Expected role='assistant', got: {msg.get('role')}" + ) + content_blocks = msg.get("content", []) + assert isinstance(content_blocks, list) and len(content_blocks) > 0, ( + f"Expected non-empty content blocks, got: {content_blocks}" + ) + text_block = next((b for b in content_blocks if "text" in b), None) + assert text_block is not None and text_block["text"], ( + f"Expected a text content block, got: {content_blocks}" + ) + assert out.get("stopReason") in ("end_turn", "max_tokens"), ( + f"Unexpected stopReason: {out.get('stopReason')}" + ) + print(f" βœ“ stopReason={out['stopReason']!r}, text={text_block['text'][:60]!r}") + + # ------------------------------------------------------------------ # + # 48. AI21 Jamba β€” messages path via invoke # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_48_invoke_ai21_messages(self, bedrock_client): + """Test Case 48: AI21 Jamba via invoke with messages array β†’ ResponsesRequest path. + + AI21 Jamba invoke with a messages array routes through ResponsesRequest and returns + the AI21 Chat Completions format: {"id": ..., "choices": [{"message": {"role": "assistant", + "content": "..."}, "finish_reason": "stop"}]}. + """ + print("\n=== Test 48: AI21 Jamba via invoke (messages path) ===") + + body = { + "messages": [{"role": "user", "content": "Say hello in one word."}], + "max_tokens": 50, + } + + try: + response = bedrock_client.invoke_model( + modelId="ai21.j2-mid-v1", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + except botocore.exceptions.ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("ResourceNotFoundException", "ValidationException"): + pytest.skip(f"Titan text model no longer available: {e}") + raise + out = json.loads(response["body"].read()) + + # Must NOT be text-completion format + assert "results" not in out and "outputs" not in out, ( + f"Request was misrouted to text-completion path β€” got keys: {list(out.keys())}" + ) + # Must be AI21 Chat Completions format + choices = out.get("choices", []) + assert isinstance(choices, list) and len(choices) > 0, ( + f"Expected non-empty 'choices', got: {out}" + ) + msg = choices[0].get("message", {}) + assert msg.get("role") == "assistant", ( + f"Expected role='assistant', got: {msg.get('role')}" + ) + assert msg.get("content"), f"Expected non-empty content, got: {msg}" + assert choices[0].get("finish_reason") in ("stop", "length"), ( + f"Unexpected finish_reason: {choices[0].get('finish_reason')}" + ) + print(f" βœ“ finish_reason={choices[0]['finish_reason']!r}, content={msg['content'][:60]!r}") + + # ------------------------------------------------------------------ # + # 49. Anthropic claude β€” messages streaming via invoke-stream # + # ------------------------------------------------------------------ # + @skip_if_no_api_key("bedrock") + def test_49_invoke_stream_anthropic_messages(self, bedrock_client): + """Test Case 49: Anthropic Claude via invoke-with-response-stream with messages array. + + Verifies the ResponsesRequest streaming path: a 'messages' payload sent to + invoke-with-response-stream returns Anthropic SSE events (message_start, + content_block_delta, message_delta, message_stop) wrapped in InvokeModelRawChunk bytes. + """ + print("\n=== Test 49: Anthropic claude via invoke-with-response-stream (messages path) ===") + + body = { + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Say hello in one word."}], + } + ], + "max_tokens": 50, + } + + try: + response = bedrock_client.invoke_model_with_response_stream( + modelId="anthropic.claude-3-haiku-20240307-v1:0", + contentType="application/json", + accept="application/json", + body=json.dumps(body), + ) + except AttributeError: + pytest.skip("invoke_model_with_response_stream not available in this boto3 version") + except Exception as e: + pytest.fail(f"invoke_model_with_response_stream failed: {e}") + + stream = response.get("body") + if stream is None: + pytest.fail("Response missing 'body' stream") + + event_types = [] + text_parts = [] + start_time = time.time() + timeout = 30 + + for event in stream: + if time.time() - start_time > timeout: + pytest.fail(f"Streaming took longer than {timeout} seconds") + + if "chunk" not in event: + continue + raw_bytes = event["chunk"].get("bytes", b"") + if not raw_bytes: + continue + try: + chunk_json = json.loads(raw_bytes.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + continue + + event_type = chunk_json.get("type", "") + event_types.append(event_type) + + # Collect text deltas + if event_type == "content_block_delta": + delta = chunk_json.get("delta", {}) + if delta.get("type") == "text_delta": + text_parts.append(delta.get("text", "")) + + # Must have seen at least message_start and message_stop + assert "message_start" in event_types, ( + f"Expected 'message_start' event in stream, got event types: {event_types}" + ) + assert "message_stop" in event_types, ( + f"Expected 'message_stop' event in stream, got event types: {event_types}" + ) + full_text = "".join(text_parts) + assert full_text, f"Expected non-empty streamed text, got: {full_text!r}" + print(f" βœ“ event_types={event_types}, text={full_text[:60]!r}") diff --git a/tests/integrations/python/tests/test_openai.py b/tests/integrations/python/tests/test_openai.py index 01ca5b0212..cd01bc8a4f 100644 --- a/tests/integrations/python/tests/test_openai.py +++ b/tests/integrations/python/tests/test_openai.py @@ -71,6 +71,11 @@ 55. Image Generation - different sizes 60. WebSocket Responses API - base path 61. WebSocket Responses API - integration paths +62. Realtime WebSocket API - base path +63. Realtime WebSocket API - integration paths +64. Realtime client secret HTTP API - raw routes +65. Realtime client secret HTTP API - OpenAI constructor base_url compatibility +66. Realtime client secret HTTP API - unsupported provider Batch API uses OpenAI SDK with x-model-provider header to route to different providers. """ @@ -79,6 +84,7 @@ import os import time from typing import Any +from urllib.parse import quote import pytest from openai import OpenAI @@ -170,7 +176,11 @@ assert_valid_openai_annotation, # WebSocket utilities WS_RESPONSES_SIMPLE_INPUT, + get_realtime_test_model, get_ws_base_url, + run_openai_base_url_client_secret_request, + run_realtime_client_secret_request, + run_ws_realtime_test, run_ws_responses_test, ) from .utils.config_loader import get_config, get_model @@ -4287,3 +4297,252 @@ def test_61_ws_responses_integration_paths(self, test_config, provider, model, v f"Unexpected non-success terminal event at {path}. Events: {event_types}" ) assert len(result["content"]) > 0, f"Should receive non-empty content at {path}" + + @pytest.mark.parametrize( + "provider,_model,vk_enabled", + get_cross_provider_params_with_vk_for_scenario( + "simple_chat", include_providers=["openai"] + ), + ) + def test_62_ws_realtime_base_path(self, test_config, provider, _model, vk_enabled): + """Test Case 62: Realtime WebSocket API via base path /v1/realtime.""" + if provider == "_no_providers_": + pytest.skip("OpenAI provider is not configured for integration tests") + _ = test_config + + realtime_model = get_realtime_test_model(provider) + if not realtime_model: + pytest.skip("Realtime model is not configured for provider") + + ws_base = get_ws_base_url() + full_model = format_provider_model(provider, realtime_model) + ws_url = f"{ws_base}/v1/realtime?model={quote(full_model, safe='')}" + api_key = get_api_key(provider) + + extra_headers = {} + if vk_enabled: + config = get_config() + vk = config.get_virtual_key() + if vk: + extra_headers["x-bf-vk"] = vk + + result = run_ws_realtime_test( + ws_url=ws_url, + api_key=api_key, + timeout=45, + extra_headers=extra_headers if extra_headers else None, + ) + + assert result["error"] is None, f"Realtime websocket returned error: {result['error']}" + assert result["got_session_created"], "Expected session.created event" + assert result["got_session_updated"], "Expected session.updated event" + assert result["got_text_delta"], ( + f"Expected at least one response.output_text.delta event. " + f"Got {[e.get('type') for e in result['events']]}" + ) + assert result["got_response_done"], ( + f"Expected response.done event. Got {[e.get('type') for e in result['events']]}" + ) + + @pytest.mark.parametrize( + "provider,_model,vk_enabled", + get_cross_provider_params_with_vk_for_scenario( + "simple_chat", include_providers=["openai"] + ), + ) + def test_63_ws_realtime_integration_paths( + self, test_config, provider, _model, vk_enabled + ): + """Test Case 63: Realtime WebSocket API via OpenAI integration paths.""" + if provider == "_no_providers_": + pytest.skip("OpenAI provider is not configured for integration tests") + _ = test_config + + realtime_model = get_realtime_test_model(provider) + if not realtime_model: + pytest.skip("Realtime model is not configured for provider") + + ws_base = get_ws_base_url() + api_key = get_api_key(provider) + + extra_headers = {} + if vk_enabled: + config = get_config() + vk = config.get_virtual_key() + if vk: + extra_headers["x-bf-vk"] = vk + + integration_urls = [ + f"{ws_base}/openai/v1/realtime?model={quote(realtime_model, safe='')}", + f"{ws_base}/openai/realtime?deployment={quote(realtime_model, safe='')}", + f"{ws_base}/openai/openai/realtime?deployment={quote(realtime_model, safe='')}", + ] + + for ws_url in integration_urls: + result = run_ws_realtime_test( + ws_url=ws_url, + api_key=api_key, + timeout=45, + extra_headers=extra_headers if extra_headers else None, + ) + + assert result["error"] is None, f"Realtime websocket returned error at {ws_url}: {result['error']}" + assert result["got_session_created"], f"Expected session.created at {ws_url}" + assert result["got_session_updated"], f"Expected session.updated at {ws_url}" + assert result["got_text_delta"], f"Expected response.output_text.delta at {ws_url}" + assert result["got_response_done"], f"Expected response.done at {ws_url}" + + @pytest.mark.parametrize( + "provider,_model,vk_enabled", + get_cross_provider_params_with_vk_for_scenario( + "simple_chat", include_providers=["openai"] + ), + ) + def test_64_realtime_client_secret_routes(self, test_config, provider, _model, vk_enabled): + """Test Case 64: Realtime client secret creation via raw HTTP routes.""" + if provider == "_no_providers_": + pytest.skip("OpenAI provider is not configured for integration tests") + _ = test_config + + realtime_model = get_realtime_test_model(provider) + if not realtime_model: + pytest.skip("Realtime model is not configured for provider") + + config = get_config() + base_url = config._config["bifrost"]["base_url"].rstrip("/") + api_key = get_api_key(provider) + + extra_headers = {} + if vk_enabled: + vk = config.get_virtual_key() + if vk: + extra_headers["x-bf-vk"] = vk + + test_cases = [ + ( + f"{base_url}/v1/realtime/client_secrets", + {"session": {"model": format_provider_model(provider, realtime_model)}}, + "client_secrets", + ), + ( + f"{base_url}/v1/realtime/sessions", + {"model": format_provider_model(provider, realtime_model)}, + "sessions", + ), + ( + f"{base_url}/openai/v1/realtime/client_secrets", + {"session": {"model": realtime_model}}, + "client_secrets", + ), + ( + f"{base_url}/openai/v1/realtime/sessions", + {"model": realtime_model}, + "sessions", + ), + ] + + for url, payload, response_shape in test_cases: + result = run_realtime_client_secret_request( + url=url, + api_key=api_key, + request_body=payload, + extra_headers=extra_headers if extra_headers else None, + timeout=45, + ) + + assert result["status_code"] == 200, ( + f"Expected 200 from {url}, got {result['status_code']}: {result['body']}" + ) + assert "session" in result["body"], f"Missing session object in response from {url}" + if response_shape == "sessions": + assert "client_secret" in result["body"], f"Missing client_secret in response from {url}" + assert result["body"]["client_secret"].get("value"), f"Missing client_secret.value from {url}" + assert result["body"]["client_secret"].get("expires_at"), f"Missing client_secret.expires_at from {url}" + else: + assert result["body"].get("value"), f"Missing top-level value from {url}" + assert result["body"].get("expires_at"), f"Missing top-level expires_at from {url}" + + @pytest.mark.parametrize( + "provider,_model,vk_enabled", + get_cross_provider_params_with_vk_for_scenario( + "simple_chat", include_providers=["openai"] + ), + ) + def test_65_realtime_client_secret_openai_base_url_compatibility( + self, test_config, provider, _model, vk_enabled + ): + """Test Case 65: Realtime client secret creation works through OpenAI constructor base_url overrides.""" + if provider == "_no_providers_": + pytest.skip("OpenAI provider is not configured for integration tests") + _ = test_config + + realtime_model = get_realtime_test_model(provider) + if not realtime_model: + pytest.skip("Realtime model is not configured for provider") + + config = get_config() + api_key = get_api_key(provider) + root_base_url = config._config["bifrost"]["base_url"].rstrip("/") + openai_base_url = config.get_integration_url("openai") + + default_headers = {} + if vk_enabled: + vk = config.get_virtual_key() + if vk: + default_headers["x-bf-vk"] = vk + + test_cases = [ + ( + root_base_url, + {"session": {"model": format_provider_model(provider, realtime_model)}}, + ), + ( + openai_base_url, + {"session": {"model": realtime_model}}, + ), + ] + + for base_url, payload in test_cases: + result = run_openai_base_url_client_secret_request( + base_url=base_url, + api_key=api_key, + request_body=payload, + timeout=45, + default_headers=default_headers if default_headers else None, + ) + + assert result["status_code"] == 200, ( + f"Expected 200 from OpenAI client using base_url={base_url}, " + f"got {result['status_code']}: {result['body']}" + ) + assert result["body"].get("value"), ( + f"Missing top-level value for base_url={base_url}" + ) + assert result["body"].get("expires_at"), ( + f"Missing top-level expires_at for base_url={base_url}" + ) + + def test_66_realtime_client_secret_unsupported_provider(self, test_config): + """Test Case 66: Base realtime client secret route rejects unsupported providers.""" + _ = test_config + + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not configured") + + config = get_config() + base_url = config._config["bifrost"]["base_url"].rstrip("/") + api_key = get_api_key("openai") + + result = run_realtime_client_secret_request( + url=f"{base_url}/v1/realtime/client_secrets", + api_key=api_key, + request_body={"session": {"model": "anthropic/claude-sonnet-4-20250514"}}, + timeout=30, + ) + + assert result["status_code"] == 400, ( + f"Expected 400 for unsupported provider, got {result['status_code']}: {result['body']}" + ) + body = result["body"] + assert "error" in body, f"Expected error object in response, got {body}" + assert "not support" in body["error"]["message"].lower() or "provider" in body["error"]["message"].lower() diff --git a/tests/integrations/python/tests/utils/common.py b/tests/integrations/python/tests/utils/common.py index f75379377b..30ca7dc234 100644 --- a/tests/integrations/python/tests/utils/common.py +++ b/tests/integrations/python/tests/utils/common.py @@ -59,6 +59,31 @@ def _create_base64_image(width: int = 64, height: int = 64) -> str: return base64.b64encode(img_bytes).decode('utf-8') BASE64_IMAGE = _create_base64_image(64, 64) +BASE64_IMAGE_LARGE = _create_base64_image(512, 512) + + +def _create_titan_mask_image(width: int = 512, height: int = 512) -> str: + """Create a base64-encoded grayscale PNG mask for Titan inpainting/outpainting. + Titan requires mask pixel values to be exactly 0 (preserve) or 255 (edit area).""" + from PIL import Image, ImageDraw + import io + import base64 + + # Grayscale image, all black (preserve) by default + mask = Image.new('L', (width, height), 0) + + # White rectangle in center marks the area to edit + draw = ImageDraw.Draw(mask) + cx, cy = width // 2, height // 2 + w, h = width // 3, height // 3 + draw.rectangle([cx - w // 2, cy - h // 2, cx + w // 2, cy + h // 2], fill=255) + + buffer = io.BytesIO() + mask.save(buffer, format='PNG') + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +BASE64_TITAN_MASK_IMAGE = _create_titan_mask_image(512, 512) # Common Test Data SIMPLE_CHAT_MESSAGES = [{"role": "user", "content": "Hello! How are you today?"}] @@ -3317,13 +3342,15 @@ def run_ws_responses_test( content = "" error = None - start_time = time.monotonic() + deadline = time.monotonic() + timeout while True: - if time.monotonic() - start_time > timeout: + remaining = deadline - time.monotonic() + if remaining <= 0: raise TimeoutError( f"WebSocket stream did not reach terminal event within {timeout}s" ) + conn.settimeout(remaining) result = conn.recv() data = json.loads(result) events.append(data) @@ -3353,4 +3380,216 @@ def run_ws_responses_test( "error": error, } finally: - conn.close() \ No newline at end of file + conn.close() + + +def get_realtime_test_model(provider: str) -> str: + """Get a Realtime test model for the given provider.""" + env_var = f"{provider.upper()}_REALTIME_MODEL" + if provider == "openai": + return os.getenv(env_var, "gpt-realtime") + return os.getenv(env_var, "") + + +def run_ws_realtime_test( + ws_url, + api_key, + timeout=30, + extra_headers=None, +): + """Connect to a Realtime websocket endpoint and drive a text-only round trip.""" + import time + import websocket as ws_client + + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if extra_headers: + headers.update(extra_headers) + + header_list = [f"{k}: {v}" for k, v in headers.items()] + conn = ws_client.create_connection(ws_url, header=header_list, timeout=timeout) + + try: + events = [] + got_session_created = False + got_session_updated = False + got_text_delta = False + got_response_done = False + content = "" + error = None + + deadline = time.monotonic() + timeout + + def recv_event(): + remaining = deadline - time.monotonic() + if remaining <= 0: + raise TimeoutError( + f"Realtime websocket did not reach terminal state within {timeout}s" + ) + conn.settimeout(remaining) + raw = conn.recv() + data = json.loads(raw) + events.append(data) + return data + + while not got_session_created and error is None: + data = recv_event() + event_type = data.get("type", "") + if event_type == "session.created": + got_session_created = True + elif event_type == "error": + error = data + + if got_session_created and error is None: + conn.send( + json.dumps( + { + "type": "session.update", + "session": { + "type": "realtime", + "output_modalities": ["text"], + }, + } + ) + ) + + while True: + data = recv_event() + event_type = data.get("type", "") + if event_type == "session.updated": + got_session_updated = True + break + if event_type == "error": + error = data + break + + if got_session_updated and error is None: + conn.send( + json.dumps( + { + "type": "conversation.item.create", + "item": { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Say hello in exactly two words.", + } + ], + }, + } + ) + ) + conn.send(json.dumps({"type": "response.create"})) + + while True: + data = recv_event() + event_type = data.get("type", "") + + if event_type == "response.output_text.delta": + delta = data.get("delta", "") + if isinstance(delta, dict): + content += delta.get("text", "") + elif isinstance(delta, str): + content += delta + got_text_delta = True + elif event_type == "response.done": + response_status = data.get("response", {}).get("status") + if response_status == "completed": + got_response_done = True + else: + error = data + break + elif event_type == "error": + error = data + break + + return { + "events": events, + "event_count": len(events), + "got_session_created": got_session_created, + "got_session_updated": got_session_updated, + "got_text_delta": got_text_delta, + "got_response_done": got_response_done, + "content": content, + "error": error, + } + finally: + conn.close() + + +def run_realtime_client_secret_request( + url, + api_key, + request_body, + extra_headers=None, + timeout=30, +): + """POST a realtime client-secret/session request and return status + body.""" + import requests + + headers = { + "Content-Type": "application/json", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if extra_headers: + headers.update(extra_headers) + + response = requests.post(url, headers=headers, json=request_body, timeout=timeout) + + try: + body = response.json() + except ValueError: + body = {"raw_body": response.text} + + return { + "status_code": response.status_code, + "body": body, + "headers": dict(response.headers), + } + + +def run_openai_base_url_client_secret_request( + base_url, + api_key, + request_body, + timeout=30, + default_headers=None, +): + """Exercise the OpenAI client constructor base_url using the SDK's public request surface.""" + import httpx + from openai import OpenAI + + merged_headers = {} + if default_headers: + merged_headers.update(default_headers) + + client = OpenAI( + api_key=api_key, + base_url=base_url, + timeout=timeout, + default_headers=merged_headers, + ) + + try: + response = client.post( + "v1/realtime/client_secrets", + cast_to=httpx.Response, + body=request_body, + ) + + try: + body = response.json() + except ValueError: + body = {"raw_body": response.text} + + return { + "status_code": response.status_code, + "body": body, + "headers": dict(response.headers), + } + finally: + client.close() diff --git a/tests/integrations/typescript/config.json b/tests/integrations/typescript/config.json index 1764f50318..cf49dba281 100644 --- a/tests/integrations/typescript/config.json +++ b/tests/integrations/typescript/config.json @@ -7,6 +7,7 @@ "name": "OpenAI API Key", "value": "env.OPENAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -20,6 +21,7 @@ "name": "Anthropic API Key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -33,6 +35,7 @@ "name": "Gemini API Key", "value": "env.GEMINI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -44,11 +47,12 @@ "keys": [ { "name": "Vertex API Key", - "vertex_key_config": { + "vertex_key_config": { "project_id": "env.GOOGLE_PROJECT_ID", "region": "env.GOOGLE_LOCATION" }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -60,7 +64,8 @@ { "name": "Mistral API Key", "value": "env.MISTRAL_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -72,7 +77,8 @@ { "name": "Cohere API Key", "value": "env.COHERE_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -84,7 +90,8 @@ { "name": "Groq API Key", "value": "env.GROQ_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -96,7 +103,8 @@ { "name": "Perplexity API Key", "value": "env.PERPLEXITY_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -108,7 +116,8 @@ { "name": "Cerebras API Key", "value": "env.CEREBRAS_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -120,7 +129,8 @@ { "name": "OpenRouter API Key", "value": "env.OPENROUTER_API_KEY", - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -136,7 +146,8 @@ "endpoint": "env.AZURE_OPENAI_ENDPOINT", "api_version": "env.AZURE_OPENAI_API_VERSION" }, - "weight": 1 + "weight": 1, + "models": ["*"] } ], "network_config": { @@ -154,6 +165,7 @@ "arn": "env.AWS_ARN" }, "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -181,7 +193,21 @@ { "id": "vk-test", "value": "sk-bf-test-key", - "is_active": true + "is_active": true, + "provider_configs": [ + { "provider": "openai", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "anthropic", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "gemini", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "vertex", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "mistral", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "cohere", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "groq", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "perplexity", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "cerebras", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "openrouter", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "azure", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 }, + { "provider": "bedrock", "allowed_models": ["*"], "key_ids": ["*"], "weight": 1.0 } + ] } ] }, @@ -197,4 +223,4 @@ "max_request_body_size_mb": 100, "enable_litellm_fallbacks": false } -} \ No newline at end of file +} diff --git a/transports/bifrost-http/handlers/asyncinference.go b/transports/bifrost-http/handlers/asyncinference.go index 5d6d8a0626..010a74a110 100644 --- a/transports/bifrost-http/handlers/asyncinference.go +++ b/transports/bifrost-http/handlers/asyncinference.go @@ -108,7 +108,7 @@ func (h *AsyncHandler) asyncTextCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -146,7 +146,7 @@ func (h *AsyncHandler) asyncChatCompletion(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -184,7 +184,7 @@ func (h *AsyncHandler) asyncResponses(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -218,7 +218,7 @@ func (h *AsyncHandler) asyncEmbeddings(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -256,7 +256,7 @@ func (h *AsyncHandler) asyncSpeech(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -294,7 +294,7 @@ func (h *AsyncHandler) asyncTranscription(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -332,7 +332,7 @@ func (h *AsyncHandler) asyncImageGeneration(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -370,7 +370,7 @@ func (h *AsyncHandler) asyncImageEdit(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -403,7 +403,7 @@ func (h *AsyncHandler) asyncImageVariation(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -436,7 +436,7 @@ func (h *AsyncHandler) asyncRerank(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context") return @@ -473,7 +473,7 @@ func (h *AsyncHandler) getJob(operationType schemas.RequestType) fasthttp.Reques } // Get the requesting user's VK for auth check - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return diff --git a/transports/bifrost-http/handlers/config.go b/transports/bifrost-http/handlers/config.go index 3343308436..69314b839b 100644 --- a/transports/bifrost-http/handlers/config.go +++ b/transports/bifrost-http/handlers/config.go @@ -47,7 +47,7 @@ type ConfigManager interface { ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error UpdateDropExcessRequests(ctx context.Context, value bool) - UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error ReloadPlugin(ctx context.Context, name string, path *string, pluginConfig any, placement *schemas.PluginPlacement, order *int) error RemovePlugin(ctx context.Context, name string) error ReloadProxyConfig(ctx context.Context, config *configstoreTables.GlobalProxyConfig) error @@ -276,9 +276,14 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { shouldReloadMCPToolManagerConfig = true } - // Only reload MCP tool manager config if MCP is configured + updatedConfig.MCPDisableAutoToolInject = payload.ClientConfig.MCPDisableAutoToolInject + if payload.ClientConfig.MCPDisableAutoToolInject != currentConfig.MCPDisableAutoToolInject { + shouldReloadMCPToolManagerConfig = true + } + + // Reload MCP tool manager config with all current values in one call if shouldReloadMCPToolManagerConfig && h.store.MCPConfig != nil { - if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel); err != nil { + if err := h.configManager.UpdateMCPToolManagerConfig(ctx, updatedConfig.MCPAgentDepth, updatedConfig.MCPToolExecutionTimeout, updatedConfig.MCPCodeModeBindingLevel, updatedConfig.MCPDisableAutoToolInject); err != nil { logger.Warn(fmt.Sprintf("failed to update mcp tool manager config: %v", err)) SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp tool manager config: %v", err)) return @@ -382,6 +387,11 @@ func (h *ConfigHandler) updateConfig(ctx *fasthttp.RequestCtx) { // Toggle whether deleted virtual keys should appear in logs filter data. updatedConfig.HideDeletedVirtualKeysInFilters = payload.ClientConfig.HideDeletedVirtualKeysInFilters + // No restart needed - routing engine reads via pointer, change is effective immediately. + if payload.ClientConfig.RoutingChainMaxDepth > 0 { + updatedConfig.RoutingChainMaxDepth = payload.ClientConfig.RoutingChainMaxDepth + } + // Handle HeaderFilterConfig changes if !headerFilterConfigEqual(payload.ClientConfig.HeaderFilterConfig, currentConfig.HeaderFilterConfig) { // Validate that no security headers are in the allowlist or denylist diff --git a/transports/bifrost-http/handlers/devpprof.go b/transports/bifrost-http/handlers/devpprof.go index fca8712d9b..19ef208237 100644 --- a/transports/bifrost-http/handlers/devpprof.go +++ b/transports/bifrost-http/handlers/devpprof.go @@ -531,7 +531,7 @@ func categorizeGoroutine(g *GoroutineGroup) { "PostMCPHook", "HTTPTransportPreHook", "HTTPTransportPostHook", - "completeAndFlushTrace", + "CompleteAndFlushTrace", "ProcessAndSend", "handleProvider", "Inject", // Observability plugin inject diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index 3a700c557f..66a95851a5 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -19,6 +19,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/framework/modelcatalog" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" @@ -40,6 +41,8 @@ type GovernanceManager interface { RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error ReloadRoutingRule(ctx context.Context, id string) error RemoveRoutingRule(ctx context.Context, id string) error + UpsertPricingOverride(ctx context.Context, override *configstoreTables.TablePricingOverride) error + DeletePricingOverride(ctx context.Context, id string) error } // GovernanceHandler manages HTTP requests for governance operations @@ -68,21 +71,22 @@ type CreateVirtualKeyRequest struct { Description string `json:"description,omitempty"` ProviderConfigs []struct { Provider string `json:"provider" validate:"required"` - Weight float64 `json:"weight,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed - Budget *CreateBudgetRequest `json:"budget,omitempty"` // Provider-level budget + Weight *float64 `json:"weight,omitempty"` + AllowedModels schemas.WhiteList `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all + Budgets []CreateBudgetRequest `json:"budgets,omitempty"` // Multi-budget for provider config RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config - } `json:"provider_configs,omitempty"` // Empty means all providers allowed + KeyIDs schemas.WhiteList `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config + } `json:"provider_configs,omitempty"` // Empty means no providers allowed (deny-by-default) MCPConfigs []struct { - MCPClientName string `json:"mcp_client_name" validate:"required"` - ToolsToExecute []string `json:"tools_to_execute,omitempty"` - } `json:"mcp_configs,omitempty"` // Empty means all MCP clients allowed + MCPClientName string `json:"mcp_client_name" validate:"required"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute,omitempty"` + } `json:"mcp_configs,omitempty"` // Empty means no MCP clients allowed (deny-by-default) TeamID *string `json:"team_id,omitempty"` // Mutually exclusive with CustomerID CustomerID *string `json:"customer_id,omitempty"` // Mutually exclusive with TeamID - Budget *CreateBudgetRequest `json:"budget,omitempty"` - RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` - IsActive *bool `json:"is_active,omitempty"` + Budgets []CreateBudgetRequest `json:"budgets,omitempty"` // Multi-budget: each must have a unique reset_duration + RateLimit *CreateRateLimitRequest `json:"rate_limit,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + CalendarAligned bool `json:"calendar_aligned,omitempty"` // When true, all budgets reset at clean calendar boundaries } // UpdateVirtualKeyRequest represents the request body for updating a virtual key @@ -92,22 +96,23 @@ type UpdateVirtualKeyRequest struct { ProviderConfigs []struct { ID *uint `json:"id,omitempty"` // null for new entries Provider string `json:"provider" validate:"required"` - Weight float64 `json:"weight,omitempty"` - AllowedModels []string `json:"allowed_models,omitempty"` // Empty means all models allowed - Budget *UpdateBudgetRequest `json:"budget,omitempty"` // Provider-level budget + Weight *float64 `json:"weight,omitempty"` + AllowedModels schemas.WhiteList `json:"allowed_models,omitempty"` // ["*"] allows all models; empty denies all + Budgets []CreateBudgetRequest `json:"budgets,omitempty"` // Multi-budget for provider config RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` // Provider-level rate limit - KeyIDs []string `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config + KeyIDs schemas.WhiteList `json:"key_ids,omitempty"` // List of DBKey UUIDs to associate with this provider config } `json:"provider_configs,omitempty"` MCPConfigs []struct { - ID *uint `json:"id,omitempty"` // null for new entries - MCPClientName string `json:"mcp_client_name" validate:"required"` - ToolsToExecute []string `json:"tools_to_execute,omitempty"` + ID *uint `json:"id,omitempty"` // null for new entries + MCPClientName string `json:"mcp_client_name" validate:"required"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute,omitempty"` } `json:"mcp_configs,omitempty"` TeamID *string `json:"team_id,omitempty"` CustomerID *string `json:"customer_id,omitempty"` - Budget *UpdateBudgetRequest `json:"budget,omitempty"` - RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` - IsActive *bool `json:"is_active,omitempty"` + Budgets []CreateBudgetRequest `json:"budgets,omitempty"` // Multi-budget: replaces all VK-level budgets + RateLimit *UpdateRateLimitRequest `json:"rate_limit,omitempty"` + IsActive *bool `json:"is_active,omitempty"` + CalendarAligned *bool `json:"calendar_aligned,omitempty"` // When true, all budgets reset at clean calendar boundaries } // CreateBudgetRequest represents the request body for creating a budget @@ -138,7 +143,8 @@ type RoutingTarget struct { type CreateRoutingRuleRequest struct { Name string `json:"name" validate:"required"` Description string `json:"description,omitempty"` - Enabled *bool `json:"enabled,omitempty"` // nil = use DB default (true) + Enabled *bool `json:"enabled,omitempty"` // nil = use DB default (true) + ChainRule *bool `json:"chain_rule,omitempty"` // nil = use DB default (false) CelExpression string `json:"cel_expression"` Targets []RoutingTarget `json:"targets"` // Required; weights must sum to 1 Fallbacks []string `json:"fallbacks,omitempty"` @@ -153,6 +159,7 @@ type UpdateRoutingRuleRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Enabled *bool `json:"enabled,omitempty"` + ChainRule *bool `json:"chain_rule,omitempty"` CelExpression *string `json:"cel_expression,omitempty"` Targets []RoutingTarget `json:"targets,omitempty"` // If provided, replaces all existing targets; weights must sum to 1 Fallbacks []string `json:"fallbacks,omitempty"` @@ -202,8 +209,8 @@ func collectProviderConfigDeleteIDs( budgetIDs []string, rateLimitIDs []string, ) ([]string, []string) { - if config.BudgetID != nil { - budgetIDs = append(budgetIDs, *config.BudgetID) + for _, b := range config.Budgets { + budgetIDs = append(budgetIDs, b.ID) } if config.RateLimitID != nil { rateLimitIDs = append(rateLimitIDs, *config.RateLimitID) @@ -308,6 +315,12 @@ func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...sche r.GET("/api/governance/providers", lib.ChainMiddlewares(h.getProviderGovernance, middlewares...)) r.PUT("/api/governance/providers/{provider_name}", lib.ChainMiddlewares(h.updateProviderGovernance, middlewares...)) r.DELETE("/api/governance/providers/{provider_name}", lib.ChainMiddlewares(h.deleteProviderGovernance, middlewares...)) + + // Pricing override operations + r.GET("/api/governance/pricing-overrides", lib.ChainMiddlewares(h.getPricingOverrides, middlewares...)) + r.POST("/api/governance/pricing-overrides", lib.ChainMiddlewares(h.createPricingOverride, middlewares...)) + r.PUT("/api/governance/pricing-overrides/{id}", lib.ChainMiddlewares(h.updatePricingOverride, middlewares...)) + r.DELETE("/api/governance/pricing-overrides/{id}", lib.ChainMiddlewares(h.deletePricingOverride, middlewares...)) } // Virtual Key CRUD Operations @@ -434,16 +447,23 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { SendError(ctx, 400, "VirtualKey cannot be attached to both Team and Customer") return } - // Validate budget if provided - if req.Budget != nil { - if req.Budget.MaxLimit < 0 { - SendError(ctx, 400, fmt.Sprintf("Budget max_limit cannot be negative: %.2f", req.Budget.MaxLimit)) - return - } - // Validate reset duration format - if _, err := configstoreTables.ParseDuration(req.Budget.ResetDuration); err != nil { - SendError(ctx, 400, fmt.Sprintf("Invalid reset duration format: %s", req.Budget.ResetDuration)) - return + // Validate budgets if provided + if len(req.Budgets) > 0 { + seenDurations := make(map[string]bool) + for _, b := range req.Budgets { + if b.MaxLimit < 0 { + SendError(ctx, 400, fmt.Sprintf("Budget max_limit cannot be negative: %.2f", b.MaxLimit)) + return + } + if _, err := configstoreTables.ParseDuration(b.ResetDuration); err != nil { + SendError(ctx, 400, fmt.Sprintf("Invalid reset duration format: %s", b.ResetDuration)) + return + } + if seenDurations[b.ResetDuration] { + SendError(ctx, 400, fmt.Sprintf("Duplicate reset_duration in budgets: %s", b.ResetDuration)) + return + } + seenDurations[b.ResetDuration] = true } } // Set defaults @@ -454,30 +474,14 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { var vk configstoreTables.TableVirtualKey if err := h.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { vk = configstoreTables.TableVirtualKey{ - ID: uuid.NewString(), - Name: req.Name, - Value: governance.GenerateVirtualKey(), - Description: req.Description, - TeamID: req.TeamID, - CustomerID: req.CustomerID, - IsActive: isActive, - } - if req.Budget != nil { - budget := configstoreTables.TableBudget{ - ID: uuid.NewString(), - MaxLimit: req.Budget.MaxLimit, - ResetDuration: req.Budget.ResetDuration, - CalendarAligned: req.Budget.CalendarAligned, - LastReset: budgetLastReset(req.Budget.CalendarAligned, req.Budget.ResetDuration), - CurrentUsage: 0, - } - if err := validateBudget(&budget); err != nil { - return err - } - if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { - return err - } - vk.BudgetID = &budget.ID + ID: uuid.NewString(), + Name: req.Name, + Value: governance.GenerateVirtualKey(), + Description: req.Description, + TeamID: req.TeamID, + CustomerID: req.CustomerID, + IsActive: isActive, + CalendarAligned: req.CalendarAligned, } if req.RateLimit != nil { rateLimit := configstoreTables.TableRateLimit{ @@ -500,22 +504,40 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { if err := h.configStore.CreateVirtualKey(ctx, &vk, tx); err != nil { return err } + // Create multi-budgets for VK + if len(req.Budgets) > 0 { + for _, b := range req.Budgets { + budget := configstoreTables.TableBudget{ + ID: uuid.NewString(), + MaxLimit: b.MaxLimit, + ResetDuration: b.ResetDuration, + LastReset: budgetLastReset(vk.CalendarAligned, b.ResetDuration), + CurrentUsage: 0, + VirtualKeyID: &vk.ID, + } + if err := validateBudget(&budget); err != nil { + return err + } + if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { + return err + } + } + } if req.ProviderConfigs != nil { for _, pc := range req.ProviderConfigs { - // Validate budget if provided - if pc.Budget != nil { - if pc.Budget.MaxLimit < 0 { - return fmt.Errorf("provider config budget max_limit cannot be negative: %.2f", pc.Budget.MaxLimit) - } - // Validate reset duration format - if _, err := configstoreTables.ParseDuration(pc.Budget.ResetDuration); err != nil { - return fmt.Errorf("invalid provider config budget reset duration format: %s", pc.Budget.ResetDuration) - } + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} } // Get keys for this provider config if specified var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { + allowAllKeys := false + if pc.KeyIDs.IsUnrestricted() { + allowAllKeys = true + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -529,29 +551,12 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ VirtualKeyID: vk.ID, Provider: pc.Provider, - Weight: &pc.Weight, + Weight: pc.Weight, AllowedModels: pc.AllowedModels, + AllowAllKeys: allowAllKeys, Keys: keys, } - // Create budget for provider config if provided - if pc.Budget != nil { - budget := configstoreTables.TableBudget{ - ID: uuid.NewString(), - MaxLimit: pc.Budget.MaxLimit, - ResetDuration: pc.Budget.ResetDuration, - CalendarAligned: pc.Budget.CalendarAligned, - LastReset: budgetLastReset(pc.Budget.CalendarAligned, pc.Budget.ResetDuration), - CurrentUsage: 0, - } - if err := validateBudget(&budget); err != nil { - return err - } - if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { - return err - } - providerConfig.BudgetID = &budget.ID - } // Create rate limit for provider config if provided if pc.RateLimit != nil { rateLimit := configstoreTables.TableRateLimit{ @@ -575,6 +580,30 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, providerConfig, tx); err != nil { return err } + // Create multi-budgets for provider config + if len(pc.Budgets) > 0 { + seenDurations := make(map[string]bool) + for _, b := range pc.Budgets { + if seenDurations[b.ResetDuration] { + return &badRequestError{err: fmt.Errorf("duplicate reset_duration in provider config budgets: %s", b.ResetDuration)} + } + seenDurations[b.ResetDuration] = true + budget := configstoreTables.TableBudget{ + ID: uuid.NewString(), + MaxLimit: b.MaxLimit, + ResetDuration: b.ResetDuration, + LastReset: budgetLastReset(vk.CalendarAligned, b.ResetDuration), + CurrentUsage: 0, + ProviderConfigID: &providerConfig.ID, + } + if err := validateBudget(&budget); err != nil { + return err + } + if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { + return err + } + } + } } } if req.MCPConfigs != nil { @@ -582,12 +611,15 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { seenMCPClientNames := make(map[string]bool) for _, mc := range req.MCPConfigs { if seenMCPClientNames[mc.MCPClientName] { - return fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName) + return &badRequestError{err: fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName)} } seenMCPClientNames[mc.MCPClientName] = true } for _, mc := range req.MCPConfigs { + if err := mc.ToolsToExecute.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid tools_to_execute for mcp client %s: %w", mc.MCPClientName, err)} + } mcpClient, err := h.configStore.GetMCPClientByName(ctx, mc.MCPClientName) if err != nil { return fmt.Errorf("failed to get MCP client: %w", err) @@ -603,8 +635,8 @@ func (h *GovernanceHandler) createVirtualKey(ctx *fasthttp.RequestCtx) { } return nil }); err != nil { - // Check if this is a duplicate MCPClientName error and return 400 instead of 500 - if strings.Contains(err.Error(), "duplicate mcp_client_name:") { + var badReqErr *badRequestError + if errors.As(err, &badReqErr) { SendError(ctx, 400, err.Error()) return } @@ -683,8 +715,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return } if err := h.configStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { - var budgetIDToDelete, rateLimitIDToDelete string - var providerBudgetIDsToDelete, providerRateLimitIDsToDelete []string + var rateLimitIDToDelete string + var providerBudgetIDsToDelete []string + var providerRateLimitIDsToDelete []string // Update fields if provided if req.Name != nil { @@ -709,72 +742,77 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if req.IsActive != nil { vk.IsActive = *req.IsActive } - // Handle budget updates - if req.Budget != nil { - if isBudgetRemovalRequest(req.Budget) { - if vk.BudgetID != nil { - budgetIDToDelete = *vk.BudgetID - vk.BudgetID = nil - vk.Budget = nil - } - } else if vk.BudgetID != nil { - // Update existing budget - budget := configstoreTables.TableBudget{} - if err := tx.First(&budget, "id = ?", *vk.BudgetID).Error; err != nil { - return err - } - - if req.Budget.MaxLimit != nil { - budget.MaxLimit = *req.Budget.MaxLimit - } - if req.Budget.ResetDuration != nil { - budget.ResetDuration = *req.Budget.ResetDuration - } - if req.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *req.Budget.CalendarAligned - if *req.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 - } - } - if err := validateBudget(&budget); err != nil { - return err - } - if err := h.configStore.UpdateBudget(ctx, &budget, tx); err != nil { - return err - } - vk.Budget = &budget - } else { - // Create new budget - if req.Budget.MaxLimit == nil || req.Budget.ResetDuration == nil { - return fmt.Errorf("both max_limit and reset_duration are required when creating a new budget") + if req.CalendarAligned != nil { + vk.CalendarAligned = *req.CalendarAligned + } + // Handle multi-budget updates + if req.Budgets != nil { + // Validate multi-budgets + seenDurations := make(map[string]bool) + for _, b := range req.Budgets { + if b.MaxLimit < 0 { + return &badRequestError{err: fmt.Errorf("budget max_limit cannot be negative: %.2f", b.MaxLimit)} } - if *req.Budget.MaxLimit < 0 { - return fmt.Errorf("budget max_limit cannot be negative: %.2f", *req.Budget.MaxLimit) + if _, err := configstoreTables.ParseDuration(b.ResetDuration); err != nil { + return &badRequestError{err: fmt.Errorf("invalid reset duration format: %s", b.ResetDuration)} } - if _, err := configstoreTables.ParseDuration(*req.Budget.ResetDuration); err != nil { - return fmt.Errorf("invalid reset duration format: %s", *req.Budget.ResetDuration) + if seenDurations[b.ResetDuration] { + return &badRequestError{err: fmt.Errorf("duplicate reset_duration in budgets: %s", b.ResetDuration)} } - calAligned := req.Budget.CalendarAligned != nil && *req.Budget.CalendarAligned - budget := configstoreTables.TableBudget{ - ID: uuid.NewString(), - MaxLimit: *req.Budget.MaxLimit, - ResetDuration: *req.Budget.ResetDuration, - CalendarAligned: calAligned, - LastReset: budgetLastReset(calAligned, *req.Budget.ResetDuration), - CurrentUsage: 0, + seenDurations[b.ResetDuration] = true } - if err := validateBudget(&budget); err != nil { - return err + + // Build map of existing budgets by reset_duration for matching + existingByDuration := make(map[string]configstoreTables.TableBudget) + for _, existing := range vk.Budgets { + existingByDuration[existing.ResetDuration] = existing } - if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { - return err + + // Reconcile: preserve existing budgets where possible, create new ones where needed + var reconciledBudgets []configstoreTables.TableBudget + matchedIDs := make(map[string]bool) + for _, b := range req.Budgets { + if existing, found := existingByDuration[b.ResetDuration]; found { + // Budget with same duration exists β€” update max_limit, preserve usage + existing.MaxLimit = b.MaxLimit + if err := validateBudget(&existing); err != nil { + return err + } + if err := h.configStore.UpdateBudget(ctx, &existing, tx); err != nil { + return err + } + reconciledBudgets = append(reconciledBudgets, existing) + matchedIDs[existing.ID] = true + } else { + // New budget duration β€” create fresh + budget := configstoreTables.TableBudget{ + ID: uuid.NewString(), + MaxLimit: b.MaxLimit, + ResetDuration: b.ResetDuration, + LastReset: budgetLastReset(vk.CalendarAligned, b.ResetDuration), + CurrentUsage: 0, + VirtualKeyID: &vk.ID, + } + if err := validateBudget(&budget); err != nil { + return err + } + if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { + return err + } + reconciledBudgets = append(reconciledBudgets, budget) + } } - vk.BudgetID = &budget.ID - vk.Budget = &budget + // Delete budgets that are no longer present + for _, existing := range vk.Budgets { + if !matchedIDs[existing.ID] { + if err := h.configStore.DeleteBudget(ctx, existing.ID, tx); err != nil { + return fmt.Errorf("failed to delete removed VK budget: %w", err) + } + } } + vk.Budgets = reconciledBudgets } + // Handle rate limit updates if req.RateLimit != nil { if isRateLimitRemovalRequest(req.RateLimit) { @@ -833,7 +871,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if req.ProviderConfigs != nil { // Get existing provider configs for comparison var existingConfigs []configstoreTables.TableVirtualKeyProviderConfig - if err := tx.Where("virtual_key_id = ?", vk.ID).Find(&existingConfigs).Error; err != nil { + if err := tx.Where("virtual_key_id = ?", vk.ID). + Preload("Budgets"). + Find(&existingConfigs).Error; err != nil { return err } // Create maps for easier lookup @@ -845,24 +885,19 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { // Process new configs: create new ones and update existing ones for _, pc := range req.ProviderConfigs { if pc.ID == nil { - // Validate budget if provided for new provider config - if pc.Budget != nil { - if pc.Budget.MaxLimit != nil && *pc.Budget.MaxLimit < 0 { - return fmt.Errorf("provider config budget max_limit cannot be negative: %.2f", *pc.Budget.MaxLimit) - } - if pc.Budget.ResetDuration != nil { - if _, err := configstoreTables.ParseDuration(*pc.Budget.ResetDuration); err != nil { - return fmt.Errorf("invalid provider config budget reset duration format: %s", *pc.Budget.ResetDuration) - } - } - // Both fields are required when creating new budget - if pc.Budget.MaxLimit == nil || pc.Budget.ResetDuration == nil { - return fmt.Errorf("both max_limit and reset_duration are required when creating a new provider budget") - } + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} } + // Get keys for this provider config if specified var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { + allowAllKeys := false + if pc.KeyIDs.IsUnrestricted() { + allowAllKeys = true + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -877,29 +912,11 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { providerConfig := &configstoreTables.TableVirtualKeyProviderConfig{ VirtualKeyID: vk.ID, Provider: pc.Provider, - Weight: &pc.Weight, + Weight: pc.Weight, AllowedModels: pc.AllowedModels, + AllowAllKeys: allowAllKeys, Keys: keys, } - // Create budget for provider config if provided - if pc.Budget != nil { - pcCalAligned := pc.Budget.CalendarAligned != nil && *pc.Budget.CalendarAligned - budget := configstoreTables.TableBudget{ - ID: uuid.NewString(), - MaxLimit: *pc.Budget.MaxLimit, - ResetDuration: *pc.Budget.ResetDuration, - CalendarAligned: pcCalAligned, - LastReset: budgetLastReset(pcCalAligned, *pc.Budget.ResetDuration), - CurrentUsage: 0, - } - if err := validateBudget(&budget); err != nil { - return err - } - if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { - return err - } - providerConfig.BudgetID = &budget.ID - } // Create rate limit for provider config if provided if pc.RateLimit != nil { rateLimit := configstoreTables.TableRateLimit{ @@ -922,6 +939,30 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { if err := h.configStore.CreateVirtualKeyProviderConfig(ctx, providerConfig, tx); err != nil { return err } + // Create multi-budgets for new provider config in update + if len(pc.Budgets) > 0 { + seenDurations := make(map[string]bool) + for _, b := range pc.Budgets { + if seenDurations[b.ResetDuration] { + return &badRequestError{err: fmt.Errorf("duplicate reset_duration in provider config budgets: %s", b.ResetDuration)} + } + seenDurations[b.ResetDuration] = true + budget := configstoreTables.TableBudget{ + ID: uuid.NewString(), + MaxLimit: b.MaxLimit, + ResetDuration: b.ResetDuration, + LastReset: budgetLastReset(vk.CalendarAligned, b.ResetDuration), + CurrentUsage: 0, + ProviderConfigID: &providerConfig.ID, + } + if err := validateBudget(&budget); err != nil { + return err + } + if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { + return err + } + } + } } else { // Update existing provider config existing, ok := existingConfigsMap[*pc.ID] @@ -929,13 +970,22 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("provider config %d does not belong to this virtual key", *pc.ID) } requestConfigsMap[*pc.ID] = true + if err := pc.AllowedModels.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid allowed_models for provider %s: %w", pc.Provider, err)} + } + if err := pc.KeyIDs.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid key_ids for provider %s: %w", pc.Provider, err)} + } existing.Provider = pc.Provider - existing.Weight = &pc.Weight + existing.Weight = pc.Weight existing.AllowedModels = pc.AllowedModels // Get keys for this provider config if specified var keys []configstoreTables.TableKey - if len(pc.KeyIDs) > 0 { + allowAllKeys := false + if pc.KeyIDs.IsUnrestricted() { + allowAllKeys = true + } else if !pc.KeyIDs.IsEmpty() { var err error keys, err = h.configStore.GetKeysByIDs(ctx, pc.KeyIDs) if err != nil { @@ -945,70 +995,75 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return fmt.Errorf("some keys not found for provider %s: expected %d, found %d", pc.Provider, len(pc.KeyIDs), len(keys)) } } + existing.AllowAllKeys = allowAllKeys existing.Keys = keys - // Handle budget updates for provider config - if pc.Budget != nil { - if isBudgetRemovalRequest(pc.Budget) { - if existing.BudgetID != nil { - providerBudgetIDsToDelete = append(providerBudgetIDsToDelete, *existing.BudgetID) - existing.BudgetID = nil - existing.Budget = nil + // Handle multi-budget updates for existing provider config + if pc.Budgets != nil { + // Validate + seenDurations := make(map[string]bool) + for _, b := range pc.Budgets { + if b.MaxLimit < 0 { + return &badRequestError{err: fmt.Errorf("provider config budget max_limit cannot be negative: %.2f", b.MaxLimit)} } - } else if existing.BudgetID != nil { - // Update existing budget - budget := configstoreTables.TableBudget{} - if err := tx.First(&budget, "id = ?", *existing.BudgetID).Error; err != nil { - return err - } - if pc.Budget.MaxLimit != nil { - budget.MaxLimit = *pc.Budget.MaxLimit - } - if pc.Budget.ResetDuration != nil { - budget.ResetDuration = *pc.Budget.ResetDuration - } - if pc.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *pc.Budget.CalendarAligned - if *pc.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 + if _, err := configstoreTables.ParseDuration(b.ResetDuration); err != nil { + return &badRequestError{err: fmt.Errorf("invalid provider config budget reset duration format: %s", b.ResetDuration)} } - } - if err := validateBudget(&budget); err != nil { - return err - } - if err := h.configStore.UpdateBudget(ctx, &budget, tx); err != nil { - return err - } - } else { - // Create new budget for existing provider config - if pc.Budget.MaxLimit == nil || pc.Budget.ResetDuration == nil { - return fmt.Errorf("both max_limit and reset_duration are required when creating a new provider budget") + if seenDurations[b.ResetDuration] { + return &badRequestError{err: fmt.Errorf("duplicate reset_duration in provider config budgets: %s", b.ResetDuration)} } - if *pc.Budget.MaxLimit < 0 { - return fmt.Errorf("provider config budget max_limit cannot be negative: %.2f", *pc.Budget.MaxLimit) - } - if _, err := configstoreTables.ParseDuration(*pc.Budget.ResetDuration); err != nil { - return fmt.Errorf("invalid provider config budget reset duration format: %s", *pc.Budget.ResetDuration) - } - pcExistCalAligned := pc.Budget.CalendarAligned != nil && *pc.Budget.CalendarAligned - budget := configstoreTables.TableBudget{ - ID: uuid.NewString(), - MaxLimit: *pc.Budget.MaxLimit, - ResetDuration: *pc.Budget.ResetDuration, - CalendarAligned: pcExistCalAligned, - LastReset: budgetLastReset(pcExistCalAligned, *pc.Budget.ResetDuration), - CurrentUsage: 0, + seenDurations[b.ResetDuration] = true } - if err := validateBudget(&budget); err != nil { - return err + + // Build map of existing budgets by reset_duration for matching + pcExistingByDuration := make(map[string]configstoreTables.TableBudget) + for _, eb := range existing.Budgets { + pcExistingByDuration[eb.ResetDuration] = eb } - if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { - return err + + // Reconcile: preserve existing budgets where possible + var pcReconciledBudgets []configstoreTables.TableBudget + pcMatchedIDs := make(map[string]bool) + for _, b := range pc.Budgets { + if eb, found := pcExistingByDuration[b.ResetDuration]; found { + // Budget with same duration exists β€” update max_limit, preserve usage + eb.MaxLimit = b.MaxLimit + if err := validateBudget(&eb); err != nil { + return err + } + if err := h.configStore.UpdateBudget(ctx, &eb, tx); err != nil { + return err + } + pcReconciledBudgets = append(pcReconciledBudgets, eb) + pcMatchedIDs[eb.ID] = true + } else { + // New budget duration β€” create fresh + budget := configstoreTables.TableBudget{ + ID: uuid.NewString(), + MaxLimit: b.MaxLimit, + ResetDuration: b.ResetDuration, + LastReset: budgetLastReset(vk.CalendarAligned, b.ResetDuration), + CurrentUsage: 0, + ProviderConfigID: &existing.ID, + } + if err := validateBudget(&budget); err != nil { + return err + } + if err := h.configStore.CreateBudget(ctx, &budget, tx); err != nil { + return err + } + pcReconciledBudgets = append(pcReconciledBudgets, budget) + } } - existing.BudgetID = &budget.ID + // Delete budgets that are no longer present + for _, eb := range existing.Budgets { + if !pcMatchedIDs[eb.ID] { + if err := h.configStore.DeleteBudget(ctx, eb.ID, tx); err != nil { + return fmt.Errorf("failed to delete removed provider config budget: %w", err) + } + } } + existing.Budgets = pcReconciledBudgets } // Handle rate limit updates for provider config if pc.RateLimit != nil { @@ -1083,7 +1138,7 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { seenMCPClientNames := make(map[string]bool) for _, mc := range req.MCPConfigs { if seenMCPClientNames[mc.MCPClientName] { - return fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName) + return &badRequestError{err: fmt.Errorf("duplicate mcp_client_name: %s", mc.MCPClientName)} } seenMCPClientNames[mc.MCPClientName] = true } @@ -1100,6 +1155,9 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { requestMCPConfigsMap := make(map[uint]bool) // Process new configs: create new ones and update existing ones for _, mc := range req.MCPConfigs { + if err := mc.ToolsToExecute.Validate(); err != nil { + return &badRequestError{err: fmt.Errorf("invalid tools_to_execute for mcp client %s: %w", mc.MCPClientName, err)} + } if mc.ID == nil { mcpClient, err := h.configStore.GetMCPClientByName(ctx, mc.MCPClientName) if err != nil { @@ -1136,11 +1194,6 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { } } - if budgetIDToDelete != "" { - if err := tx.Delete(&configstoreTables.TableBudget{}, "id = ?", budgetIDToDelete).Error; err != nil { - return err - } - } if rateLimitIDToDelete != "" { if err := tx.Delete(&configstoreTables.TableRateLimit{}, "id = ?", rateLimitIDToDelete).Error; err != nil { return err @@ -1159,11 +1212,10 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { return nil }); err != nil { - errMsg := err.Error() - // Check if this is a duplicate MCPClientName error and return 400 instead of 500 - if strings.Contains(errMsg, "duplicate mcp_client_name:") || - strings.Contains(errMsg, "already exists'") || - strings.Contains(errMsg, "duplicate key") { + var badReqErr *badRequestError + if errors.As(err, &badReqErr) || + strings.Contains(err.Error(), "already exists") || + strings.Contains(err.Error(), "duplicate key") { SendError(ctx, 400, fmt.Sprintf("Failed to update virtual key: %v", err)) return } @@ -1176,7 +1228,12 @@ func (h *GovernanceHandler) updateVirtualKey(ctx *fasthttp.RequestCtx) { logger.Error("failed to load relationships for updated VK: %v", err) preloadedVk = vk } - h.governanceManager.ReloadVirtualKey(ctx, vk.ID) + if _, err := h.governanceManager.ReloadVirtualKey(ctx, vk.ID); err != nil { + // Should never happen but just in case + logger.Error("failed to reload virtual key after update: %v", err) + SendError(ctx, 500, "Virtual key updated in database but failed to reload in-memory state") + return + } SendJSON(ctx, map[string]interface{}{ "message": "Virtual key updated successfully", "virtual_key": preloadedVk, @@ -1352,8 +1409,7 @@ func (h *GovernanceHandler) createTeam(ctx *fasthttp.RequestCtx) { ID: uuid.NewString(), MaxLimit: req.Budget.MaxLimit, ResetDuration: req.Budget.ResetDuration, - CalendarAligned: req.Budget.CalendarAligned, - LastReset: budgetLastReset(req.Budget.CalendarAligned, req.Budget.ResetDuration), + LastReset: budgetLastReset(false, req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -1493,14 +1549,6 @@ func (h *GovernanceHandler) updateTeam(ctx *fasthttp.RequestCtx) { if req.Budget.ResetDuration != nil { budget.ResetDuration = *req.Budget.ResetDuration } - if req.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *req.Budget.CalendarAligned - if *req.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 - } - } if err := validateBudget(&budget); err != nil { return err } @@ -1519,13 +1567,11 @@ func (h *GovernanceHandler) updateTeam(ctx *fasthttp.RequestCtx) { if _, err := configstoreTables.ParseDuration(*req.Budget.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", *req.Budget.ResetDuration) } - teamCalAligned := req.Budget.CalendarAligned != nil && *req.Budget.CalendarAligned budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *req.Budget.MaxLimit, ResetDuration: *req.Budget.ResetDuration, - CalendarAligned: teamCalAligned, - LastReset: budgetLastReset(teamCalAligned, *req.Budget.ResetDuration), + LastReset: budgetLastReset(false, *req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -1764,8 +1810,7 @@ func (h *GovernanceHandler) createCustomer(ctx *fasthttp.RequestCtx) { ID: uuid.NewString(), MaxLimit: req.Budget.MaxLimit, ResetDuration: req.Budget.ResetDuration, - CalendarAligned: req.Budget.CalendarAligned, - LastReset: budgetLastReset(req.Budget.CalendarAligned, req.Budget.ResetDuration), + LastReset: budgetLastReset(false, req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -1895,14 +1940,6 @@ func (h *GovernanceHandler) updateCustomer(ctx *fasthttp.RequestCtx) { if req.Budget.ResetDuration != nil { budget.ResetDuration = *req.Budget.ResetDuration } - if req.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *req.Budget.CalendarAligned - if *req.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 - } - } if err := validateBudget(&budget); err != nil { return err } @@ -1921,13 +1958,11 @@ func (h *GovernanceHandler) updateCustomer(ctx *fasthttp.RequestCtx) { if _, err := configstoreTables.ParseDuration(*req.Budget.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", *req.Budget.ResetDuration) } - custCalAligned := req.Budget.CalendarAligned != nil && *req.Budget.CalendarAligned budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *req.Budget.MaxLimit, ResetDuration: *req.Budget.ResetDuration, - CalendarAligned: custCalAligned, - LastReset: budgetLastReset(custCalAligned, *req.Budget.ResetDuration), + LastReset: budgetLastReset(false, *req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -2152,9 +2187,6 @@ func validateBudget(budget *configstoreTables.TableBudget) error { if _, err := configstoreTables.ParseDuration(budget.ResetDuration); err != nil { return fmt.Errorf("invalid budget reset duration format: %s", budget.ResetDuration) } - if budget.CalendarAligned && !configstoreTables.IsCalendarAlignableDuration(budget.ResetDuration) { - return fmt.Errorf("calendar_aligned is not supported for reset duration %q: only daily (d), weekly (w), monthly (M), and yearly (Y) periods support calendar alignment", budget.ResetDuration) - } return nil } @@ -2317,8 +2349,7 @@ func (h *GovernanceHandler) createModelConfig(ctx *fasthttp.RequestCtx) { ID: uuid.NewString(), MaxLimit: req.Budget.MaxLimit, ResetDuration: req.Budget.ResetDuration, - CalendarAligned: req.Budget.CalendarAligned, - LastReset: budgetLastReset(req.Budget.CalendarAligned, req.Budget.ResetDuration), + LastReset: budgetLastReset(false, req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -2423,14 +2454,6 @@ func (h *GovernanceHandler) updateModelConfig(ctx *fasthttp.RequestCtx) { if req.Budget.ResetDuration != nil { budget.ResetDuration = *req.Budget.ResetDuration } - if req.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *req.Budget.CalendarAligned - if *req.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 - } - } if err := validateBudget(&budget); err != nil { return err } @@ -2449,13 +2472,11 @@ func (h *GovernanceHandler) updateModelConfig(ctx *fasthttp.RequestCtx) { if _, err := configstoreTables.ParseDuration(*req.Budget.ResetDuration); err != nil { return fmt.Errorf("invalid reset duration format: %s", *req.Budget.ResetDuration) } - mcCalAligned := req.Budget.CalendarAligned != nil && *req.Budget.CalendarAligned budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *req.Budget.MaxLimit, ResetDuration: *req.Budget.ResetDuration, - CalendarAligned: mcCalAligned, - LastReset: budgetLastReset(mcCalAligned, *req.Budget.ResetDuration), + LastReset: budgetLastReset(false, *req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -2695,14 +2716,6 @@ func (h *GovernanceHandler) updateProviderGovernance(ctx *fasthttp.RequestCtx) { if req.Budget.ResetDuration != nil { budget.ResetDuration = *req.Budget.ResetDuration } - if req.Budget.CalendarAligned != nil { - wasCalendarAligned := budget.CalendarAligned - budget.CalendarAligned = *req.Budget.CalendarAligned - if *req.Budget.CalendarAligned && !wasCalendarAligned { - budget.LastReset = configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, time.Now()) - budget.CurrentUsage = 0 - } - } if err := validateBudget(&budget); err != nil { return err } @@ -2715,13 +2728,11 @@ func (h *GovernanceHandler) updateProviderGovernance(ctx *fasthttp.RequestCtx) { if req.Budget.MaxLimit == nil || req.Budget.ResetDuration == nil { return fmt.Errorf("both max_limit and reset_duration are required when creating a new budget") } - provCalAligned := req.Budget.CalendarAligned != nil && *req.Budget.CalendarAligned budget := configstoreTables.TableBudget{ ID: uuid.NewString(), MaxLimit: *req.Budget.MaxLimit, ResetDuration: *req.Budget.ResetDuration, - CalendarAligned: provCalAligned, - LastReset: budgetLastReset(provCalAligned, *req.Budget.ResetDuration), + LastReset: budgetLastReset(false, *req.Budget.ResetDuration), CurrentUsage: 0, } if err := validateBudget(&budget); err != nil { @@ -3132,16 +3143,21 @@ func (h *GovernanceHandler) createRoutingRule(ctx *fasthttp.RequestCtx) { } // Create routing rule - // Handle Enabled: nil means use DB default (true), otherwise use provided value + // Handle Enabled/ChainRule: nil means use DB default (true/false), otherwise use provided value enabled := true // DB default if req.Enabled != nil { enabled = *req.Enabled } + chainRule := false // DB default + if req.ChainRule != nil { + chainRule = *req.ChainRule + } rule := &configstoreTables.TableRoutingRule{ ID: ruleID, Name: req.Name, Description: req.Description, Enabled: enabled, + ChainRule: chainRule, CelExpression: req.CelExpression, Targets: targets, Scope: scope, @@ -3201,6 +3217,9 @@ func (h *GovernanceHandler) updateRoutingRule(ctx *fasthttp.RequestCtx) { if req.Enabled != nil { rule.Enabled = *req.Enabled } + if req.ChainRule != nil { + rule.ChainRule = *req.ChainRule + } if req.CelExpression != nil { rule.CelExpression = *req.CelExpression } @@ -3295,6 +3314,376 @@ func (h *GovernanceHandler) deleteRoutingRule(ctx *fasthttp.RequestCtx) { }) } +// --------------------------------------------------------------------------- +// Pricing Override Operations +// --------------------------------------------------------------------------- + +// CreatePricingOverrideRequest is the request payload for creating a governance +// pricing override. +type CreatePricingOverrideRequest struct { + Name string `json:"name"` + ScopeKind modelcatalog.ScopeKind `json:"scope_kind"` + VirtualKeyID *string `json:"virtual_key_id,omitempty"` + ProviderID *string `json:"provider_id,omitempty"` + ProviderKeyID *string `json:"provider_key_id,omitempty"` + MatchType modelcatalog.MatchType `json:"match_type"` + Pattern string `json:"pattern"` + RequestTypes []schemas.RequestType `json:"request_types,omitempty"` + Patch modelcatalog.PricingOptions `json:"patch,omitempty"` +} + +// nullableString tracks whether a JSON string field was explicitly present in +// the request body (even as null), so the merge logic can distinguish "omitted" +// (leave existing value) from "set to null" (clear the value). +type nullableString struct { + Value *string + Set bool +} + +func (n *nullableString) UnmarshalJSON(b []byte) error { + n.Set = true + if string(b) == "null" { + n.Value = nil + return nil + } + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + n.Value = &s + return nil +} + +// UpdatePricingOverrideRequest is the request payload for updating a governance +// pricing override. All fields except Patch are optional β€” omitted fields are +// merged from the existing record. Patch is always replaced in full. +type UpdatePricingOverrideRequest struct { + Name *string `json:"name,omitempty"` + ScopeKind *modelcatalog.ScopeKind `json:"scope_kind,omitempty"` + VirtualKeyID nullableString `json:"virtual_key_id"` + ProviderID nullableString `json:"provider_id"` + ProviderKeyID nullableString `json:"provider_key_id"` + MatchType *modelcatalog.MatchType `json:"match_type,omitempty"` + Pattern *string `json:"pattern,omitempty"` + RequestTypes []schemas.RequestType `json:"request_types,omitempty"` + Patch *modelcatalog.PricingOptions `json:"patch,omitempty"` +} + +func (h *GovernanceHandler) getPricingOverrides(ctx *fasthttp.RequestCtx) { + // Parse filter parameters + var scopeKind, virtualKeyID, providerID, providerKeyID *string + if v := strings.TrimSpace(string(ctx.QueryArgs().Peek("scope_kind"))); v != "" { + scopeKind = &v + } + if v := strings.TrimSpace(string(ctx.QueryArgs().Peek("virtual_key_id"))); v != "" { + virtualKeyID = &v + } + if v := strings.TrimSpace(string(ctx.QueryArgs().Peek("provider_id"))); v != "" { + providerID = &v + } + if v := strings.TrimSpace(string(ctx.QueryArgs().Peek("provider_key_id"))); v != "" { + providerKeyID = &v + } + + // Check for pagination parameters + limitStr := string(ctx.QueryArgs().Peek("limit")) + offsetStr := string(ctx.QueryArgs().Peek("offset")) + search := string(ctx.QueryArgs().Peek("search")) + + if limitStr != "" || offsetStr != "" || search != "" { + params := configstore.PricingOverridesQueryParams{ + Search: search, + ScopeKind: scopeKind, + VirtualKeyID: virtualKeyID, + ProviderID: providerID, + ProviderKeyID: providerKeyID, + } + if limitStr != "" { + n, err := strconv.Atoi(limitStr) + if err != nil { + SendError(ctx, 400, "Invalid limit parameter: must be a number") + return + } + if n < 0 { + SendError(ctx, 400, "Invalid limit parameter: must be non-negative") + return + } + params.Limit = n + } + if offsetStr != "" { + n, err := strconv.Atoi(offsetStr) + if err != nil { + SendError(ctx, 400, "Invalid offset parameter: must be a number") + return + } + if n < 0 { + SendError(ctx, 400, "Invalid offset parameter: must be non-negative") + return + } + params.Offset = n + } + + params.Limit, params.Offset = ClampPaginationParams(params.Limit, params.Offset) + overrides, totalCount, err := h.configStore.GetPricingOverridesPaginated(ctx, params) + if err != nil { + logger.Error("failed to retrieve pricing overrides: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to retrieve pricing overrides") + return + } + SendJSON(ctx, map[string]interface{}{ + "pricing_overrides": overrides, + "count": len(overrides), + "total_count": totalCount, + "limit": params.Limit, + "offset": params.Offset, + }) + return + } + + // Non-paginated path: return all matching overrides (backward compatible) + filters := configstore.PricingOverrideFilters{ + ScopeKind: scopeKind, + VirtualKeyID: virtualKeyID, + ProviderID: providerID, + ProviderKeyID: providerKeyID, + } + overrides, err := h.configStore.GetPricingOverrides(ctx, filters) + if err != nil { + logger.Error("failed to retrieve pricing overrides: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to retrieve pricing overrides") + return + } + + SendJSON(ctx, map[string]interface{}{ + "pricing_overrides": overrides, + "count": len(overrides), + "total_count": len(overrides), + "limit": len(overrides), + "offset": 0, + }) +} + +func (h *GovernanceHandler) createPricingOverride(ctx *fasthttp.RequestCtx) { + var req CreatePricingOverrideRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid JSON") + return + } + + name, err := normalizeAndValidatePricingOverrideName(req.Name) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + shape := modelcatalog.PricingOverride{ + ScopeKind: req.ScopeKind, + VirtualKeyID: req.VirtualKeyID, + ProviderID: req.ProviderID, + ProviderKeyID: req.ProviderKeyID, + MatchType: req.MatchType, + Pattern: req.Pattern, + RequestTypes: req.RequestTypes, + } + if err := shape.IsValid(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + patchJSON, err := sonic.Marshal(req.Patch) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid patch") + return + } + + now := time.Now() + override := configstoreTables.TablePricingOverride{ + ID: uuid.NewString(), + Name: name, + ScopeKind: string(req.ScopeKind), + VirtualKeyID: normalizeOptionalString(req.VirtualKeyID), + ProviderID: normalizeOptionalString(req.ProviderID), + ProviderKeyID: normalizeOptionalString(req.ProviderKeyID), + MatchType: string(req.MatchType), + Pattern: strings.TrimSpace(req.Pattern), + RequestTypes: req.RequestTypes, + PricingPatchJSON: string(patchJSON), + ConfigHash: "", + CreatedAt: now, + UpdatedAt: now, + } + + if err := h.configStore.CreatePricingOverride(ctx, &override); err != nil { + logger.Error("failed to create pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to create pricing override") + return + } + + if err := h.governanceManager.UpsertPricingOverride(ctx, &override); err != nil { + logger.Error("failed to upsert pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to upsert pricing override") + return + } + SendJSONWithStatus(ctx, map[string]interface{}{ + "message": "Pricing override created successfully", + "pricing_override": override, + }, fasthttp.StatusCreated) +} + +func (h *GovernanceHandler) updatePricingOverride(ctx *fasthttp.RequestCtx) { + id := ctx.UserValue("id").(string) + + var req UpdatePricingOverrideRequest + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid JSON") + return + } + + existing, err := h.configStore.GetPricingOverrideByID(ctx, id) + if err != nil { + if errors.Is(err, configstore.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, "Pricing override not found") + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to retrieve pricing override: %v", err)) + return + } + + // Merge request fields onto the existing record; omitted fields keep their current values. + merged := modelcatalog.PricingOverride{ + ScopeKind: modelcatalog.ScopeKind(existing.ScopeKind), + VirtualKeyID: existing.VirtualKeyID, + ProviderID: existing.ProviderID, + ProviderKeyID: existing.ProviderKeyID, + MatchType: modelcatalog.MatchType(existing.MatchType), + Pattern: existing.Pattern, + RequestTypes: existing.RequestTypes, + } + if req.ScopeKind != nil { + merged.ScopeKind = *req.ScopeKind + // Changing scope_kind resets all scope IDs; only what the request + // explicitly provides will be kept. + merged.VirtualKeyID = nil + merged.ProviderID = nil + merged.ProviderKeyID = nil + } + if req.VirtualKeyID.Set { + merged.VirtualKeyID = req.VirtualKeyID.Value + } + if req.ProviderID.Set { + merged.ProviderID = req.ProviderID.Value + } + if req.ProviderKeyID.Set { + merged.ProviderKeyID = req.ProviderKeyID.Value + } + if req.MatchType != nil { + merged.MatchType = *req.MatchType + } + if req.Pattern != nil { + merged.Pattern = *req.Pattern + } + if req.RequestTypes != nil { + merged.RequestTypes = req.RequestTypes + } + + if err := merged.IsValid(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + // Resolve name: use provided value or fall back to existing. + nameStr := existing.Name + if req.Name != nil { + nameStr, err = normalizeAndValidatePricingOverrideName(*req.Name) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + } + + // Patch JSON: always replace in full with whatever is provided (or keep existing if omitted). + pricingPatchJSON := existing.PricingPatchJSON + if req.Patch != nil { + b, err := sonic.Marshal(req.Patch) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid patch") + return + } + pricingPatchJSON = string(b) + } + + override := configstoreTables.TablePricingOverride{ + ID: id, + Name: nameStr, + ScopeKind: string(merged.ScopeKind), + VirtualKeyID: normalizeOptionalString(merged.VirtualKeyID), + ProviderID: normalizeOptionalString(merged.ProviderID), + ProviderKeyID: normalizeOptionalString(merged.ProviderKeyID), + MatchType: string(merged.MatchType), + Pattern: strings.TrimSpace(merged.Pattern), + RequestTypes: merged.RequestTypes, + PricingPatchJSON: pricingPatchJSON, + ConfigHash: existing.ConfigHash, + CreatedAt: existing.CreatedAt, + UpdatedAt: time.Now(), + } + + if err := h.configStore.UpdatePricingOverride(ctx, &override); err != nil { + logger.Error("failed to update pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to update pricing override") + return + } + + if err := h.governanceManager.UpsertPricingOverride(ctx, &override); err != nil { + logger.Error("failed to upsert pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to upsert pricing override") + return + } + SendJSON(ctx, map[string]interface{}{ + "message": "Pricing override updated successfully", + "pricing_override": override, + }) +} + +func (h *GovernanceHandler) deletePricingOverride(ctx *fasthttp.RequestCtx) { + id := ctx.UserValue("id").(string) + if err := h.configStore.DeletePricingOverride(ctx, id); err != nil { + if errors.Is(err, configstore.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, "Pricing override not found") + return + } + logger.Error("failed to delete pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to delete pricing override") + return + } + + if err := h.governanceManager.DeletePricingOverride(ctx, id); err != nil { + logger.Warn("failed to delete pricing override from memory: %v", err) + } + SendJSON(ctx, map[string]interface{}{ + "message": "Pricing override deleted successfully", + }) +} + +func normalizeAndValidatePricingOverrideName(name string) (string, error) { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return "", errors.New("name is required") + } + return trimmed, nil +} + +func normalizeOptionalString(value *string) *string { + if value == nil { + return nil + } + trimmed := strings.TrimSpace(*value) + if trimmed == "" { + return nil + } + return &trimmed +} + // validRoutingScopes contains the allowed scope values for routing rules var validRoutingScopes = map[string]bool{ "global": true, diff --git a/transports/bifrost-http/handlers/governance_test.go b/transports/bifrost-http/handlers/governance_test.go index 0fb9c7e40f..581a22e7b1 100644 --- a/transports/bifrost-http/handlers/governance_test.go +++ b/transports/bifrost-http/handlers/governance_test.go @@ -272,7 +272,7 @@ func TestCollectProviderConfigDeleteIDs(t *testing.T) { { name: "collects both IDs", config: configstoreTables.TableVirtualKeyProviderConfig{ - BudgetID: &budgetID, + Budgets: []configstoreTables.TableBudget{{ID: budgetID}}, RateLimitID: &rateLimitID, }, wantBudgetIDs: []string{budgetID}, @@ -281,7 +281,7 @@ func TestCollectProviderConfigDeleteIDs(t *testing.T) { { name: "appends to existing slices", config: configstoreTables.TableVirtualKeyProviderConfig{ - BudgetID: &budgetID, + Budgets: []configstoreTables.TableBudget{{ID: budgetID}}, RateLimitID: &rateLimitID, }, initialBudgetIDs: []string{"budget-0"}, diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index c5e3cee1c7..8877d5675a 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -92,6 +92,8 @@ var chatParamsKnownFields = map[string]bool{ "presence_penalty": true, "prompt_cache_key": true, "reasoning": true, + "reasoning_effort": true, + "reasoning_max_tokens": true, "response_format": true, "safety_identifier": true, "service_tier": true, @@ -185,6 +187,8 @@ var imageGenerationParamsKnownFields = map[string]bool{ "negative_prompt": true, "num_inference_steps": true, "user": true, + "aspect_ratio": true, + "input_images": true, } // imageEditParamsKnownFields contains known fields for image edit requests @@ -513,6 +517,16 @@ func parseFallbacks(fallbackStrings []string) ([]schemas.Fallback, error) { return fallbacks, nil } +func effectiveStream(bodyStream *bool, bifrostCtx *schemas.BifrostContext) bool { + if bodyStream != nil { + return *bodyStream + } + if v, ok := bifrostCtx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool); ok && v { + return true + } + return false +} + // extractExtraParams processes unknown fields from JSON data into ExtraParams func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]any, error) { // Parse JSON to extract unknown fields @@ -681,7 +695,7 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { provider := string(ctx.QueryArgs().Peek("provider")) // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -739,14 +753,17 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { for i, modelEntry := range resp.Data { provider, modelName := schemas.ParseModelString(modelEntry.ID, "") pricingEntry := h.config.ModelCatalog.GetPricingEntryForModel(modelName, provider) - if pricingEntry == nil && modelEntry.Deployment != nil { - // Retry with deployment - pricingEntry = h.config.ModelCatalog.GetPricingEntryForModel(*modelEntry.Deployment, provider) + if pricingEntry == nil && modelEntry.Alias != nil { + // Retry with alias + pricingEntry = h.config.ModelCatalog.GetPricingEntryForModel(*modelEntry.Alias, provider) } if pricingEntry != nil && modelEntry.Pricing == nil { - pricing := &schemas.Pricing{ - Prompt: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.InputCostPerToken)), - Completion: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.OutputCostPerToken)), + pricing := &schemas.Pricing{} + if pricingEntry.InputCostPerToken != nil { + pricing.Prompt = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerToken)) + } + if pricingEntry.OutputCostPerToken != nil { + pricing.Completion = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.OutputCostPerToken)) } if pricingEntry.InputCostPerImage != nil { pricing.Image = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerImage)) @@ -754,6 +771,9 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) { if pricingEntry.CacheReadInputTokenCost != nil { pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheReadInputTokenCost)) } + if pricingEntry.CacheCreationInputTokenCost != nil { + pricing.InputCacheWrite = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheCreationInputTokenCost)) + } resp.Data[i].Pricing = pricing } } @@ -808,7 +828,7 @@ func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -919,12 +939,12 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return } - if req.Stream != nil && *req.Stream { + if effectiveStream(req.Stream, bifrostCtx) { h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel) return } @@ -1013,13 +1033,13 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return } - if req.Stream != nil && *req.Stream { + if effectiveStream(req.Stream, bifrostCtx) { h.handleStreamingResponses(ctx, bifrostResponsesReq, bifrostCtx, cancel) return } @@ -1087,7 +1107,7 @@ func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1180,7 +1200,7 @@ func (h *CompletionHandler) rerank(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1252,7 +1272,7 @@ func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1379,7 +1399,7 @@ func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1419,7 +1439,7 @@ func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -1763,7 +1783,7 @@ func (h *CompletionHandler) imageGeneration(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { cancel() SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -1828,11 +1848,6 @@ func prepareImageEditRequest(ctx *fasthttp.RequestCtx) (*ImageEditHTTPRequest, * editType = typeValues[0] } promptValues := form.Value["prompt"] - if editType != "background_removal" { - if len(promptValues) == 0 || promptValues[0] == "" { - return nil, nil, fmt.Errorf("prompt is required") - } - } var imageFiles []*multipart.FileHeader if imageFilesArray := form.File["image[]"]; len(imageFilesArray) > 0 { imageFiles = imageFilesArray @@ -1976,7 +1991,7 @@ func (h *CompletionHandler) imageEdit(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2119,7 +2134,7 @@ func (h *CompletionHandler) imageVariation(ctx *fasthttp.RequestCtx) { return } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return @@ -2188,7 +2203,7 @@ func (h *CompletionHandler) videoGeneration(ctx *fasthttp.RequestCtx) { Fallbacks: fallbacks, } - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) if bifrostCtx == nil { cancel() SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2242,7 +2257,7 @@ func (h *CompletionHandler) videoRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2300,7 +2315,7 @@ func (h *CompletionHandler) videoDownload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2362,7 +2377,7 @@ func (h *CompletionHandler) videoList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2413,7 +2428,7 @@ func (h *CompletionHandler) videoDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2490,7 +2505,7 @@ func (h *CompletionHandler) videoRemix(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2558,7 +2573,7 @@ func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2618,7 +2633,7 @@ func (h *CompletionHandler) batchList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2664,7 +2679,7 @@ func (h *CompletionHandler) batchRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2710,7 +2725,7 @@ func (h *CompletionHandler) batchCancel(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2756,7 +2771,7 @@ func (h *CompletionHandler) batchResults(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2845,7 +2860,7 @@ func (h *CompletionHandler) fileUpload(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2911,7 +2926,7 @@ func (h *CompletionHandler) fileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -2957,7 +2972,7 @@ func (h *CompletionHandler) fileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3003,7 +3018,7 @@ func (h *CompletionHandler) fileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3049,7 +3064,7 @@ func (h *CompletionHandler) fileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3110,7 +3125,7 @@ func (h *CompletionHandler) containerCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3169,7 +3184,7 @@ func (h *CompletionHandler) containerList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3216,7 +3231,7 @@ func (h *CompletionHandler) containerRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3263,7 +3278,7 @@ func (h *CompletionHandler) containerDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3360,7 +3375,7 @@ func (h *CompletionHandler) containerFileCreate(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3420,7 +3435,7 @@ func (h *CompletionHandler) containerFileList(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3475,7 +3490,7 @@ func (h *CompletionHandler) containerFileRetrieve(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3530,7 +3545,7 @@ func (h *CompletionHandler) containerFileContent(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -3585,7 +3600,7 @@ func (h *CompletionHandler) containerFileDelete(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/integrations.go b/transports/bifrost-http/handlers/integrations.go index 2e5c7d7aec..da9290f117 100644 --- a/transports/bifrost-http/handlers/integrations.go +++ b/transports/bifrost-http/handlers/integrations.go @@ -12,13 +12,16 @@ import ( // IntegrationHandler manages HTTP requests for AI provider integrations type IntegrationHandler struct { - extensions []integrations.ExtensionRouter - wsResponses *WSResponsesHandler + extensions []integrations.ExtensionRouter + wsResponses *WSResponsesHandler + wsRealtime *WSRealtimeHandler + webrtcRealtime *WebRTCRealtimeHandler + realtimeClientSecrets *RealtimeClientSecretsHandler } // NewIntegrationHandler creates a new integration handler instance. -// wsResponses may be nil if WebSocket support is not configured. -func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler) *IntegrationHandler { +// WebSocket handlers may be nil if WebSocket support is not configured. +func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler, wsRealtime *WSRealtimeHandler, webrtcRealtime *WebRTCRealtimeHandler, realtimeClientSecrets *RealtimeClientSecretsHandler) *IntegrationHandler { // Initialize all available integration routers extensions := []integrations.ExtensionRouter{ integrations.NewOpenAIRouter(client, handlerStore, logger), @@ -37,8 +40,11 @@ func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStor } return &IntegrationHandler{ - extensions: extensions, - wsResponses: wsResponses, + extensions: extensions, + wsResponses: wsResponses, + wsRealtime: wsRealtime, + webrtcRealtime: webrtcRealtime, + realtimeClientSecrets: realtimeClientSecrets, } } @@ -52,6 +58,30 @@ func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...sch if h.wsResponses != nil { h.wsResponses.RegisterRoutes(r, middlewares...) } + if h.wsRealtime != nil { + h.wsRealtime.RegisterRoutes(r, middlewares...) + } + if h.webrtcRealtime != nil { + h.webrtcRealtime.RegisterRoutes(r, middlewares...) + } + if h.realtimeClientSecrets != nil { + h.realtimeClientSecrets.RegisterRoutes(r, middlewares...) + } +} + +func (h *IntegrationHandler) Close() { + if h == nil { + return + } + if h.wsResponses != nil { + h.wsResponses.Close() + } + if h.wsRealtime != nil { + h.wsRealtime.Close() + } + if h.webrtcRealtime != nil { + h.webrtcRealtime.Close() + } } // SetLargePayloadHook sets the large payload detection hook on all integration routers diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go index 7d48f0945a..dc33e67d10 100644 --- a/transports/bifrost-http/handlers/logging.go +++ b/transports/bifrost-http/handlers/logging.go @@ -29,6 +29,19 @@ type LoggingHandler struct { config *lib.Config } +// Keep session log page size in one place so the session sheet limit is easy to tune later. +const sessionLogPageLimit = 500 + +func parseParentRequestIDFilter(ctx *fasthttp.RequestCtx) string { + if parentRequestID := string(ctx.QueryArgs().Peek("parent_request_id")); strings.TrimSpace(parentRequestID) != "" { + return parentRequestID + } + if sessionID := string(ctx.QueryArgs().Peek("session_id")); strings.TrimSpace(sessionID) != "" { + return sessionID + } + return "" +} + type RedactedKeysManager interface { GetAllRedactedKeys(ctx context.Context, ids []string) []schemas.Key GetAllRedactedVirtualKeys(ctx context.Context, ids []string) []tables.TableVirtualKey @@ -55,6 +68,8 @@ func (h *LoggingHandler) shouldHideDeletedVirtualKeysInFilters() bool { func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // LLM Log retrieval with filtering, search, and pagination r.GET("/api/logs", lib.ChainMiddlewares(h.getLogs, middlewares...)) + r.GET("/api/logs/sessions/{session_id}/summary", lib.ChainMiddlewares(h.getLogSessionSummaryByID, middlewares...)) + r.GET("/api/logs/sessions/{session_id}", lib.ChainMiddlewares(h.getLogSessionByID, middlewares...)) r.GET("/api/logs/{id}", lib.ChainMiddlewares(h.getLogByID, middlewares...)) r.GET("/api/logs/stats", lib.ChainMiddlewares(h.getLogsStats, middlewares...)) r.GET("/api/logs/histogram", lib.ChainMiddlewares(h.getLogsHistogram, middlewares...)) @@ -65,6 +80,9 @@ func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas r.GET("/api/logs/histogram/cost/by-provider", lib.ChainMiddlewares(h.getLogsProviderCostHistogram, middlewares...)) r.GET("/api/logs/histogram/tokens/by-provider", lib.ChainMiddlewares(h.getLogsProviderTokenHistogram, middlewares...)) r.GET("/api/logs/histogram/latency/by-provider", lib.ChainMiddlewares(h.getLogsProviderLatencyHistogram, middlewares...)) + r.GET("/api/logs/histogram/cost/by-dimension", lib.ChainMiddlewares(h.getLogsDimensionCostHistogram, middlewares...)) + r.GET("/api/logs/histogram/tokens/by-dimension", lib.ChainMiddlewares(h.getLogsDimensionTokenHistogram, middlewares...)) + r.GET("/api/logs/histogram/latency/by-dimension", lib.ChainMiddlewares(h.getLogsDimensionLatencyHistogram, middlewares...)) r.GET("/api/logs/dropped", lib.ChainMiddlewares(h.getDroppedRequests, middlewares...)) r.GET("/api/logs/filterdata", lib.ChainMiddlewares(h.getAvailableFilterData, middlewares...)) r.GET("/api/logs/rankings", lib.ChainMiddlewares(h.getModelRankings, middlewares...)) @@ -81,6 +99,126 @@ func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas r.DELETE("/api/mcp-logs", lib.ChainMiddlewares(h.deleteMCPLogs, middlewares...)) } +// getLogSessionByID handles GET /api/logs/sessions/{session_id} - Get logs in a single session. +func (h *LoggingHandler) getLogSessionByID(ctx *fasthttp.RequestCtx) { + rawSessionID, ok := ctx.UserValue("session_id").(string) + if !ok || strings.TrimSpace(rawSessionID) == "" { + SendError(ctx, fasthttp.StatusBadRequest, "session_id is required") + return + } + + pagination := &logstore.PaginationOptions{ + Limit: sessionLogPageLimit, + Offset: 0, + SortBy: "timestamp", + Order: "asc", + } + if limit := string(ctx.QueryArgs().Peek("limit")); limit != "" { + i, err := strconv.Atoi(limit) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "invalid limit") + return + } + if i <= 0 { + SendError(ctx, fasthttp.StatusBadRequest, "limit must be greater than 0") + return + } + if i > sessionLogPageLimit { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("limit cannot exceed %d", sessionLogPageLimit)) + return + } + pagination.Limit = i + } + if offset := string(ctx.QueryArgs().Peek("offset")); offset != "" { + i, err := strconv.Atoi(offset) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "invalid offset") + return + } + if i < 0 { + SendError(ctx, fasthttp.StatusBadRequest, "offset cannot be negative") + return + } + pagination.Offset = i + } + if order := string(ctx.QueryArgs().Peek("order")); order != "" { + if order != "asc" && order != "desc" { + SendError(ctx, fasthttp.StatusBadRequest, "order must be 'asc' or 'desc'") + return + } + pagination.Order = order + } + + result, err := h.logManager.GetSessionLogs(ctx, rawSessionID, pagination) + if err != nil { + logger.Error("failed to fetch session logs: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Session fetch failed: %v", err)) + return + } + + selectedKeyIDs := make(map[string]struct{}) + virtualKeyIDs := make(map[string]struct{}) + routingRuleIDs := make(map[string]struct{}) + for _, log := range result.Logs { + if log.SelectedKeyID != "" { + selectedKeyIDs[log.SelectedKeyID] = struct{}{} + } + if log.VirtualKeyID != nil && *log.VirtualKeyID != "" { + virtualKeyIDs[*log.VirtualKeyID] = struct{}{} + } + if log.RoutingRuleID != nil && *log.RoutingRuleID != "" { + routingRuleIDs[*log.RoutingRuleID] = struct{}{} + } + } + + toSlice := func(m map[string]struct{}) []string { + if len(m) == 0 { + return nil + } + out := make([]string, 0, len(m)) + for id := range m { + out = append(out, id) + } + return out + } + + redactedKeys := h.redactedKeysManager.GetAllRedactedKeys(ctx, toSlice(selectedKeyIDs)) + redactedVirtualKeys := h.redactedKeysManager.GetAllRedactedVirtualKeys(ctx, toSlice(virtualKeyIDs)) + redactedRoutingRules := h.redactedKeysManager.GetAllRedactedRoutingRules(ctx, toSlice(routingRuleIDs)) + + for i, log := range result.Logs { + if log.SelectedKeyID != "" && log.SelectedKeyName != "" { + result.Logs[i].SelectedKey = findRedactedKey(redactedKeys, log.SelectedKeyID, log.SelectedKeyName) + } + if log.VirtualKeyID != nil && log.VirtualKeyName != nil && *log.VirtualKeyID != "" && *log.VirtualKeyName != "" { + result.Logs[i].VirtualKey = findRedactedVirtualKey(redactedVirtualKeys, *log.VirtualKeyID, *log.VirtualKeyName) + } + if log.RoutingRuleID != nil && log.RoutingRuleName != nil && *log.RoutingRuleID != "" && *log.RoutingRuleName != "" { + result.Logs[i].RoutingRule = findRedactedRoutingRule(redactedRoutingRules, *log.RoutingRuleID, *log.RoutingRuleName) + } + } + + SendJSON(ctx, result) +} + +// getLogSessionSummaryByID handles GET /api/logs/sessions/{session_id}/summary - Get aggregate totals for a single session. +func (h *LoggingHandler) getLogSessionSummaryByID(ctx *fasthttp.RequestCtx) { + rawSessionID, ok := ctx.UserValue("session_id").(string) + if !ok || strings.TrimSpace(rawSessionID) == "" { + SendError(ctx, fasthttp.StatusBadRequest, "session_id is required") + return + } + + result, err := h.logManager.GetSessionSummary(ctx, rawSessionID) + if err != nil { + logger.Error("failed to fetch session summary: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Session summary fetch failed: %v", err)) + return + } + + SendJSON(ctx, result) +} + // getLogs handles GET /api/logs - Get logs with filtering, search, and pagination via query parameters func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { // Parse query parameters into filters @@ -94,12 +232,18 @@ func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } @@ -109,6 +253,18 @@ func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { if routingRuleIDs := string(ctx.QueryArgs().Peek("routing_rule_ids")); routingRuleIDs != "" { filters.RoutingRuleIDs = parseCommaSeparated(routingRuleIDs) } + if teamIDs := string(ctx.QueryArgs().Peek("team_ids")); teamIDs != "" { + filters.TeamIDs = parseCommaSeparated(teamIDs) + } + if customerIDs := string(ctx.QueryArgs().Peek("customer_ids")); customerIDs != "" { + filters.CustomerIDs = parseCommaSeparated(customerIDs) + } + if userIDs := string(ctx.QueryArgs().Peek("user_ids")); userIDs != "" { + filters.UserIDs = parseCommaSeparated(userIDs) + } + if businessUnitIDs := string(ctx.QueryArgs().Peek("business_unit_ids")); businessUnitIDs != "" { + filters.BusinessUnitIDs = parseCommaSeparated(businessUnitIDs) + } if routingEngines := string(ctx.QueryArgs().Peek("routing_engine_used")); routingEngines != "" { filters.RoutingEngineUsed = parseCommaSeparated(routingEngines) } @@ -305,12 +461,18 @@ func (h *LoggingHandler) getLogsStats(ctx *fasthttp.RequestCtx) { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } @@ -320,6 +482,18 @@ func (h *LoggingHandler) getLogsStats(ctx *fasthttp.RequestCtx) { if routingRuleIDs := string(ctx.QueryArgs().Peek("routing_rule_ids")); routingRuleIDs != "" { filters.RoutingRuleIDs = parseCommaSeparated(routingRuleIDs) } + if teamIDs := string(ctx.QueryArgs().Peek("team_ids")); teamIDs != "" { + filters.TeamIDs = parseCommaSeparated(teamIDs) + } + if customerIDs := string(ctx.QueryArgs().Peek("customer_ids")); customerIDs != "" { + filters.CustomerIDs = parseCommaSeparated(customerIDs) + } + if userIDs := string(ctx.QueryArgs().Peek("user_ids")); userIDs != "" { + filters.UserIDs = parseCommaSeparated(userIDs) + } + if businessUnitIDs := string(ctx.QueryArgs().Peek("business_unit_ids")); businessUnitIDs != "" { + filters.BusinessUnitIDs = parseCommaSeparated(businessUnitIDs) + } if routingEngines := string(ctx.QueryArgs().Peek("routing_engine_used")); routingEngines != "" { filters.RoutingEngineUsed = parseCommaSeparated(routingEngines) } @@ -434,12 +608,18 @@ func parseHistogramFilters(ctx *fasthttp.RequestCtx) *logstore.SearchFilters { if models := string(ctx.QueryArgs().Peek("models")); models != "" { filters.Models = parseCommaSeparated(models) } + if aliases := string(ctx.QueryArgs().Peek("aliases")); aliases != "" { + filters.Aliases = parseCommaSeparated(aliases) + } if statuses := string(ctx.QueryArgs().Peek("status")); statuses != "" { filters.Status = parseCommaSeparated(statuses) } if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } @@ -449,6 +629,18 @@ func parseHistogramFilters(ctx *fasthttp.RequestCtx) *logstore.SearchFilters { if routingRuleIDs := string(ctx.QueryArgs().Peek("routing_rule_ids")); routingRuleIDs != "" { filters.RoutingRuleIDs = parseCommaSeparated(routingRuleIDs) } + if teamIDs := string(ctx.QueryArgs().Peek("team_ids")); teamIDs != "" { + filters.TeamIDs = parseCommaSeparated(teamIDs) + } + if customerIDs := string(ctx.QueryArgs().Peek("customer_ids")); customerIDs != "" { + filters.CustomerIDs = parseCommaSeparated(customerIDs) + } + if userIDs := string(ctx.QueryArgs().Peek("user_ids")); userIDs != "" { + filters.UserIDs = parseCommaSeparated(userIDs) + } + if businessUnitIDs := string(ctx.QueryArgs().Peek("business_unit_ids")); businessUnitIDs != "" { + filters.BusinessUnitIDs = parseCommaSeparated(businessUnitIDs) + } if routingEngines := string(ctx.QueryArgs().Peek("routing_engine_used")); routingEngines != "" { filters.RoutingEngineUsed = parseCommaSeparated(routingEngines) } @@ -610,6 +802,78 @@ func (h *LoggingHandler) getLogsProviderLatencyHistogram(ctx *fasthttp.RequestCt SendJSON(ctx, result) } +// parseDimension extracts and validates the "dimension" query parameter. +// Returns the validated HistogramDimension and true on success, or sends an error response and returns false. +func parseDimension(ctx *fasthttp.RequestCtx) (logstore.HistogramDimension, bool) { + dim := logstore.HistogramDimension(string(ctx.QueryArgs().Peek("dimension"))) + if dim == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Missing required query parameter: dimension. Valid values: provider, team_id, customer_id, user_id, business_unit_id") + return "", false + } + if !logstore.ValidHistogramDimensions[dim] { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid dimension: %s. Valid values: provider, team_id, customer_id, user_id, business_unit_id", dim)) + return "", false + } + return dim, true +} + +// getLogsDimensionCostHistogram handles GET /api/logs/histogram/cost/by-dimension +// Returns time-bucketed cost data grouped by the dimension specified in the "dimension" query param. +func (h *LoggingHandler) getLogsDimensionCostHistogram(ctx *fasthttp.RequestCtx) { + dimension, ok := parseDimension(ctx) + if !ok { + return + } + filters := parseHistogramFilters(ctx) + bucketSizeSeconds := calculateBucketSize(filters.StartTime, filters.EndTime) + + result, err := h.logManager.GetDimensionCostHistogram(ctx, filters, bucketSizeSeconds, dimension) + if err != nil { + logger.Error("failed to get dimension cost histogram: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Dimension cost histogram calculation failed: %v", err)) + return + } + SendJSON(ctx, result) +} + +// getLogsDimensionTokenHistogram handles GET /api/logs/histogram/tokens/by-dimension +// Returns time-bucketed token usage grouped by the dimension specified in the "dimension" query param. +func (h *LoggingHandler) getLogsDimensionTokenHistogram(ctx *fasthttp.RequestCtx) { + dimension, ok := parseDimension(ctx) + if !ok { + return + } + filters := parseHistogramFilters(ctx) + bucketSizeSeconds := calculateBucketSize(filters.StartTime, filters.EndTime) + + result, err := h.logManager.GetDimensionTokenHistogram(ctx, filters, bucketSizeSeconds, dimension) + if err != nil { + logger.Error("failed to get dimension token histogram: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Dimension token histogram calculation failed: %v", err)) + return + } + SendJSON(ctx, result) +} + +// getLogsDimensionLatencyHistogram handles GET /api/logs/histogram/latency/by-dimension +// Returns time-bucketed latency percentiles grouped by the dimension specified in the "dimension" query param. +func (h *LoggingHandler) getLogsDimensionLatencyHistogram(ctx *fasthttp.RequestCtx) { + dimension, ok := parseDimension(ctx) + if !ok { + return + } + filters := parseHistogramFilters(ctx) + bucketSizeSeconds := calculateBucketSize(filters.StartTime, filters.EndTime) + + result, err := h.logManager.GetDimensionLatencyHistogram(ctx, filters, bucketSizeSeconds, dimension) + if err != nil { + logger.Error("failed to get dimension latency histogram: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Dimension latency histogram calculation failed: %v", err)) + return + } + SendJSON(ctx, result) +} + // getDroppedRequests handles GET /api/logs/dropped - Get the number of dropped requests func (h *LoggingHandler) getDroppedRequests(ctx *fasthttp.RequestCtx) { droppedRequests := h.logManager.GetDroppedRequests(ctx) @@ -636,10 +900,15 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { var ( models []string + aliases []string selectedKeys []logging.KeyPair virtualKeys []logging.KeyPair routingRules []logging.KeyPair routingEngines []string + teams []logging.KeyPair + customers []logging.KeyPair + users []logging.KeyPair + businessUnits []logging.KeyPair metadataKeys map[string][]string mu sync.Mutex ) @@ -653,6 +922,13 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { mu.Unlock() return nil }) + g.Go(func() error { + result := h.logManager.GetAvailableAliases(gCtx) + mu.Lock() + aliases = result + mu.Unlock() + return nil + }) g.Go(func() error { result := h.logManager.GetAvailableSelectedKeys(gCtx) mu.Lock() @@ -681,6 +957,34 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { mu.Unlock() return nil }) + g.Go(func() error { + result := h.logManager.GetAvailableTeams(gCtx) + mu.Lock() + teams = result + mu.Unlock() + return nil + }) + g.Go(func() error { + result := h.logManager.GetAvailableCustomers(gCtx) + mu.Lock() + customers = result + mu.Unlock() + return nil + }) + g.Go(func() error { + result := h.logManager.GetAvailableUsers(gCtx) + mu.Lock() + users = result + mu.Unlock() + return nil + }) + g.Go(func() error { + result := h.logManager.GetAvailableBusinessUnits(gCtx) + mu.Lock() + businessUnits = result + mu.Unlock() + return nil + }) g.Go(func() error { result, err := h.logManager.GetAvailableMetadataKeys(gCtx) if err != nil { @@ -780,7 +1084,7 @@ func (h *LoggingHandler) getAvailableFilterData(ctx *fasthttp.RequestCtx) { if metadataKeys == nil { metadataKeys = make(map[string][]string) } - SendJSON(ctx, map[string]interface{}{"models": models, "selected_keys": selectedKeysArray, "virtual_keys": virtualKeysArray, "routing_rules": routingRulesArray, "routing_engines": routingEngines, "metadata_keys": metadataKeys}) + SendJSON(ctx, map[string]interface{}{"models": models, "aliases": aliases, "selected_keys": selectedKeysArray, "virtual_keys": virtualKeysArray, "routing_rules": routingRulesArray, "routing_engines": routingEngines, "teams": teams, "customers": customers, "users": users, "business_units": businessUnits, "metadata_keys": metadataKeys}) } // deleteLogs handles DELETE /api/logs - Delete logs by their IDs diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index b9409f63ce..31999287fd 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "fmt" - "slices" "sort" "strconv" "time" @@ -20,6 +19,7 @@ import ( configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" + "gorm.io/gorm" ) type MCPManager interface { @@ -27,23 +27,30 @@ type MCPManager interface { RemoveMCPClient(ctx context.Context, id string) error UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error ReconnectMCPClient(ctx context.Context, id string) error + // VerifyPerUserOAuthConnection verifies an MCP server using a temporary access + // token and discovers available tools. The connection is closed after verification. + VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) + // SetClientTools updates the tool map for an existing client. + SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) } // MCPHandler manages HTTP requests for MCP tool operations type MCPHandler struct { - client *bifrost.Bifrost - store *lib.Config - mcpManager MCPManager - oauthHandler *OAuthHandler + client *bifrost.Bifrost + store *lib.Config + mcpManager MCPManager + governanceManager GovernanceManager + oauthHandler *OAuthHandler } // NewMCPHandler creates a new MCP handler instance -func NewMCPHandler(mcpManager MCPManager, client *bifrost.Bifrost, store *lib.Config, oauthHandler *OAuthHandler) *MCPHandler { +func NewMCPHandler(mcpManager MCPManager, governanceManager GovernanceManager, client *bifrost.Bifrost, store *lib.Config, oauthHandler *OAuthHandler) *MCPHandler { return &MCPHandler{ - client: client, - store: store, - mcpManager: mcpManager, - oauthHandler: oauthHandler, + client: client, + store: store, + mcpManager: mcpManager, + governanceManager: governanceManager, + oauthHandler: oauthHandler, } } @@ -57,11 +64,19 @@ func (h *MCPHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.Bif r.POST("/api/mcp/client/{id}/complete-oauth", lib.ChainMiddlewares(h.completeMCPClientOAuth, middlewares...)) } +// MCPVKConfigResponse is a VK assignment enriched with the VK's display name. +type MCPVKConfigResponse struct { + VirtualKeyID string `json:"virtual_key_id"` + VirtualKeyName string `json:"virtual_key_name"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute"` +} + // MCPClientResponse represents the response structure for MCP clients type MCPClientResponse struct { - Config *schemas.MCPClientConfig `json:"config"` - Tools []schemas.ChatToolFunction `json:"tools"` - State schemas.MCPConnectionState `json:"state"` + Config *schemas.MCPClientConfig `json:"config"` + Tools []schemas.ChatToolFunction `json:"tools"` + State schemas.MCPConnectionState `json:"state"` + VKConfigs []MCPVKConfigResponse `json:"vk_configs"` } // getMCPClients handles GET /api/mcp/clients - Get all MCP clients @@ -190,6 +205,30 @@ func (h *MCPHandler) getMCPClientsPaginated(ctx *fasthttp.RequestCtx, limitStr, connectedClientsMap[client.Config.ID] = client } + // Build VK idβ†’name lookup from in-memory governance data (no extra DB queries) + vkNameByID := make(map[string]string) + if h.governanceManager != nil { + if gd := h.governanceManager.GetGovernanceData(); gd != nil { + for _, vk := range gd.VirtualKeys { + vkNameByID[vk.ID] = vk.Name + } + } + } + + // Batch-fetch all VK assignments for this page in a single query, then group by client ID. + assignmentsByClientID := make(map[uint][]configstoreTables.TableVirtualKeyMCPConfig) + if h.store.ConfigStore != nil { + dbClientIDs := make([]uint, 0, len(dbClients)) + for _, c := range dbClients { + dbClientIDs = append(dbClientIDs, c.ID) + } + if allAssignments, err := h.store.ConfigStore.GetVirtualKeyMCPConfigsByMCPClientIDs(ctx, dbClientIDs); err == nil { + for _, a := range allAssignments { + assignmentsByClientID[a.MCPClientID] = append(assignmentsByClientID[a.MCPClientID], a) + } + } + } + // Convert DB rows to MCPClientConfig and merge with engine state clients := make([]MCPClientResponse, 0, len(dbClients)) for _, dbClient := range dbClients { @@ -198,20 +237,31 @@ func (h *MCPHandler) getMCPClientsPaginated(ctx *fasthttp.RequestCtx, limitStr, isPingAvailable = *dbClient.IsPingAvailable } clientConfig := &schemas.MCPClientConfig{ - ID: dbClient.ClientID, - Name: dbClient.Name, - IsCodeModeClient: dbClient.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), - ConnectionString: dbClient.ConnectionString, - StdioConfig: dbClient.StdioConfig, - AuthType: schemas.MCPAuthType(dbClient.AuthType), - OauthConfigID: dbClient.OauthConfigID, - ToolsToExecute: dbClient.ToolsToExecute, - ToolsToAutoExecute: dbClient.ToolsToAutoExecute, - Headers: dbClient.Headers, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, - ToolPricing: dbClient.ToolPricing, + ID: dbClient.ClientID, + Name: dbClient.Name, + IsCodeModeClient: dbClient.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(dbClient.ConnectionType), + ConnectionString: dbClient.ConnectionString, + StdioConfig: dbClient.StdioConfig, + AuthType: schemas.MCPAuthType(dbClient.AuthType), + OauthConfigID: dbClient.OauthConfigID, + ToolsToExecute: dbClient.ToolsToExecute, + ToolsToAutoExecute: dbClient.ToolsToAutoExecute, + Headers: dbClient.Headers, + AllowedExtraHeaders: dbClient.AllowedExtraHeaders, + IsPingAvailable: &isPingAvailable, + ToolSyncInterval: time.Duration(dbClient.ToolSyncInterval) * time.Minute, + ToolPricing: dbClient.ToolPricing, + AllowOnAllVirtualKeys: dbClient.AllowOnAllVirtualKeys, + } + // Enrich VK assignments using the pre-fetched batch result (no extra DB call per client) + vkConfigs := []MCPVKConfigResponse{} + for _, a := range assignmentsByClientID[dbClient.ID] { + vkConfigs = append(vkConfigs, MCPVKConfigResponse{ + VirtualKeyID: a.VirtualKeyID, + VirtualKeyName: vkNameByID[a.VirtualKeyID], + ToolsToExecute: a.ToolsToExecute, + }) } redactedConfig := h.store.RedactMCPClientConfig(clientConfig) if connectedClient, exists := connectedClientsMap[clientConfig.ID]; exists { @@ -221,15 +271,17 @@ func (h *MCPHandler) getMCPClientsPaginated(ctx *fasthttp.RequestCtx, limitStr, return sortedTools[i].Name < sortedTools[j].Name }) clients = append(clients, MCPClientResponse{ - Config: redactedConfig, - Tools: sortedTools, - State: connectedClient.State, + Config: redactedConfig, + Tools: sortedTools, + State: connectedClient.State, + VKConfigs: vkConfigs, }) } else { clients = append(clients, MCPClientResponse{ - Config: redactedConfig, - Tools: []schemas.ChatToolFunction{}, - State: schemas.MCPConnectionStateError, + Config: redactedConfig, + Tools: []schemas.ChatToolFunction{}, + State: schemas.MCPConnectionStateError, + VKConfigs: vkConfigs, }) } } @@ -280,6 +332,18 @@ type MCPClientRequest struct { OauthConfig *OAuthConfigRequest `json:"oauth_config,omitempty"` } +// MCPVKConfigRequest represents a per-VK tool access config for an MCP client +type MCPVKConfigRequest struct { + VirtualKeyID string `json:"virtual_key_id"` + ToolsToExecute schemas.WhiteList `json:"tools_to_execute"` +} + +// MCPClientUpdateRequest wraps TableMCPClient and adds optional VK assignment management +type MCPClientUpdateRequest struct { + configstoreTables.TableMCPClient + VKConfigs *[]MCPVKConfigRequest `json:"vk_configs,omitempty"` +} + // addMCPClient handles POST /api/mcp/client - Add a new MCP client func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { if h.store.ConfigStore == nil { @@ -303,8 +367,8 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Auto-clear tools_to_auto_execute if tools_to_execute is empty // If no tools are allowed to execute, no tools can be auto-executed - if len(req.ToolsToExecute) == 0 { - req.ToolsToAutoExecute = []string{} + if req.ToolsToExecute.IsEmpty() { + req.ToolsToAutoExecute = schemas.WhiteList{} } if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid tools_to_auto_execute: %v", err)) @@ -314,8 +378,99 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } + if err := validateAllowedExtraHeaders(req.AllowedExtraHeaders); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid allowed_extra_headers: %v", err)) + return + } + + // Handle per-user OAuth: admin does a test OAuth login to verify the configuration. + // Uses the same pending_oauth pattern as server-level OAuth, but on completion we + // verify the connection, discover tools, save the client, and discard the admin's token. + if req.AuthType == "per_user_oauth" { + if req.OauthConfig == nil { + SendError(ctx, fasthttp.StatusBadRequest, "OAuth configuration is required when auth_type is 'per_user_oauth'") + return + } + + if req.OauthConfig.ClientID == "" && req.ConnectionString.GetValue() == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Either client_id must be provided, or server URL must be set for OAuth discovery and dynamic client registration") + return + } + + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host) + + flowInitiation, err := h.oauthHandler.InitiateOAuthFlow(ctx, OAuthInitiationRequest{ + ClientID: req.OauthConfig.ClientID, + ClientSecret: req.OauthConfig.ClientSecret, + AuthorizeURL: req.OauthConfig.AuthorizeURL, + TokenURL: req.OauthConfig.TokenURL, + RegistrationURL: req.OauthConfig.RegistrationURL, + RedirectURI: redirectURI, + Scopes: req.OauthConfig.Scopes, + ServerURL: req.ConnectionString.GetValue(), + }) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to initiate OAuth flow: %v", err)) + return + } + + toolSyncInterval := mcp.DefaultToolSyncInterval + if req.ToolSyncInterval != 0 { + toolSyncInterval = time.Duration(req.ToolSyncInterval) * time.Minute + } else { + config, err := h.store.ConfigStore.GetClientConfig(ctx) + if err == nil && config != nil { + toolSyncInterval = time.Duration(config.MCPToolSyncInterval) * time.Minute + } + } + + isPingAvailable := true + if req.IsPingAvailable != nil { + isPingAvailable = *req.IsPingAvailable + } + + pendingConfig := schemas.MCPClientConfig{ + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + IsPingAvailable: &isPingAvailable, + ToolSyncInterval: toolSyncInterval, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + AuthType: schemas.MCPAuthTypePerUserOauth, + OauthConfigID: &flowInitiation.OauthConfigID, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + ToolPricing: req.ToolPricing, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + AllowOnAllVirtualKeys: req.AllowOnAllVirtualKeys, + } + + if err := h.oauthHandler.StorePendingMCPClient(flowInitiation.OauthConfigID, pendingConfig); err != nil { + logger.Error(fmt.Sprintf("[Add MCP Client] Failed to store pending MCP client: %v", err)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to store pending MCP client: %v", err)) + return + } + + SendJSON(ctx, map[string]any{ + "status": "pending_oauth", + "message": "Test OAuth configuration: please authorize to verify the setup. This login is only used to verify connectivity and discover available tools β€” it will not be saved.", + "oauth_config_id": flowInitiation.OauthConfigID, + "authorize_url": flowInitiation.AuthorizeURL, + "expires_at": flowInitiation.ExpiresAt, + "mcp_client_id": req.ClientID, + }) + return + } - // Check if OAuth flow is needed + // Check if server-level OAuth flow is needed if req.AuthType == "oauth" { if req.OauthConfig == nil { SendError(ctx, fasthttp.StatusBadRequest, "OAuth configuration is required when auth_type is 'oauth'") @@ -375,27 +530,25 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } } - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } - // Store MCP client config in OAuth provider memory (not in database) // It will be stored in database only after OAuth completion pendingConfig := schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ConnectionType: schemas.MCPConnectionType(req.ConnectionType), - ConnectionString: req.ConnectionString, - StdioConfig: req.StdioConfig, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: &flowInitiation.OauthConfigID, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: &flowInitiation.OauthConfigID, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + ToolPricing: req.ToolPricing, + AllowOnAllVirtualKeys: req.AllowOnAllVirtualKeys, } // Store pending config in database (associated with oauth_config_id for multi-instance support) @@ -432,26 +585,23 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - // Dereference IsPingAvailable pointer, defaulting to true if nil (new clients default to ping available) - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } schemasConfig := &schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(req.ConnectionType), - ConnectionString: req.ConnectionString, - StdioConfig: req.StdioConfig, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, - AuthType: schemas.MCPAuthType(req.AuthType), - OauthConfigID: req.OauthConfigID, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ToolPricing: req.ToolPricing, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + AuthType: schemas.MCPAuthType(req.AuthType), + OauthConfigID: req.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, + AllowOnAllVirtualKeys: req.AllowOnAllVirtualKeys, } // Creating MCP client config in config store @@ -491,8 +641,8 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid id: %v", err)) return } - // Accept the full table client config to support tool_pricing - var req *configstoreTables.TableMCPClient + // Accept the full table client config to support tool_pricing, plus optional vk_configs + var req MCPClientUpdateRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err)) return @@ -505,8 +655,8 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } // Auto-clear tools_to_auto_execute if tools_to_execute is empty // If no tools are allowed to execute, no tools can be auto-executed - if len(req.ToolsToExecute) == 0 { - req.ToolsToAutoExecute = []string{} + if req.ToolsToExecute.IsEmpty() { + req.ToolsToAutoExecute = schemas.WhiteList{} } // Validate tools_to_auto_execute if err := validateToolsToAutoExecute(req.ToolsToAutoExecute, req.ToolsToExecute); err != nil { @@ -518,6 +668,10 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid client name: %v", err)) return } + if err := validateAllowedExtraHeaders(req.AllowedExtraHeaders); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid allowed_extra_headers: %v", err)) + return + } // Get existing config to handle redacted values var existingConfig *schemas.MCPClientConfig if h.store.MCPConfig != nil { @@ -534,7 +688,8 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } // Merge redacted values - preserve old values if incoming values are redacted and unchanged - req = mergeMCPRedactedValues(req, existingConfig, h.store.RedactMCPClientConfig(existingConfig)) + merged := mergeMCPRedactedValues(&req.TableMCPClient, existingConfig, h.store.RedactMCPClientConfig(existingConfig)) + req.TableMCPClient = *merged // Save existing DB config before update so we can rollback if memory update fails var oldDBConfig *configstoreTables.TableMCPClient if h.store.ConfigStore != nil { @@ -547,7 +702,7 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } // Persist changes to config store if h.store.ConfigStore != nil { - if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, req); err != nil { + if err := h.store.ConfigStore.UpdateMCPClientConfig(ctx, id, &req.TableMCPClient); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to update mcp client config in store: %v", err)) return } @@ -566,25 +721,23 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { } } // Convert to schemas.MCPClientConfig for runtime bifrost client (without tool_pricing) - isPingAvailable := true - if req.IsPingAvailable != nil { - isPingAvailable = *req.IsPingAvailable - } schemasConfig := &schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - ConnectionType: existingConfig.ConnectionType, - ConnectionString: existingConfig.ConnectionString, - StdioConfig: existingConfig.StdioConfig, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, - Headers: req.Headers, - AuthType: existingConfig.AuthType, - OauthConfigID: existingConfig.OauthConfigID, - IsPingAvailable: isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ToolPricing: req.ToolPricing, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + ConnectionType: existingConfig.ConnectionType, + ConnectionString: existingConfig.ConnectionString, + StdioConfig: existingConfig.StdioConfig, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, + Headers: req.Headers, + AllowedExtraHeaders: req.AllowedExtraHeaders, + AuthType: existingConfig.AuthType, + OauthConfigID: existingConfig.OauthConfigID, + IsPingAvailable: req.IsPingAvailable, + ToolSyncInterval: toolSyncInterval, + ToolPricing: req.ToolPricing, + AllowOnAllVirtualKeys: req.AllowOnAllVirtualKeys, } // Update MCP client in memory if err := h.mcpManager.UpdateMCPClient(ctx, id, schemasConfig); err != nil { @@ -599,6 +752,100 @@ func (h *MCPHandler) updateMCPClient(ctx *fasthttp.RequestCtx) { return } + // Manage VK assignments if vk_configs was provided + if req.VKConfigs != nil && h.store.ConfigStore != nil { + current, err := h.store.ConfigStore.GetVirtualKeyMCPConfigsByMCPClientID(ctx, oldDBConfig.ID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("failed to get current VK MCP configs: %v", err)) + return + } + // Index current assignments by VK ID for diffing + currentByVKID := make(map[string]*configstoreTables.TableVirtualKeyMCPConfig, len(current)) + for i := range current { + currentByVKID[current[i].VirtualKeyID] = ¤t[i] + } + // Validate and reject empty/duplicate virtual_key_id entries + seen := make(map[string]struct{}, len(*req.VKConfigs)) + for _, vc := range *req.VKConfigs { + if vc.VirtualKeyID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "virtual_key_id must not be empty") + return + } + if _, exists := seen[vc.VirtualKeyID]; exists { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("duplicate virtual_key_id in vk_configs: %s", vc.VirtualKeyID)) + return + } + seen[vc.VirtualKeyID] = struct{}{} + } + // Validate tools_to_execute before entering the transaction so failures return 400 + for _, vc := range *req.VKConfigs { + if err := vc.ToolsToExecute.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid tools_to_execute for virtual key %s: %v", vc.VirtualKeyID, err)) + return + } + } + // Index requested assignments by VK ID + requestedByVKID := make(map[string]MCPVKConfigRequest, len(*req.VKConfigs)) + for _, vc := range *req.VKConfigs { + requestedByVKID[vc.VirtualKeyID] = vc + } + if err := h.store.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { + // Create or update + for _, vc := range *req.VKConfigs { + if existing, ok := currentByVKID[vc.VirtualKeyID]; ok { + existing.ToolsToExecute = vc.ToolsToExecute + if err := h.store.ConfigStore.UpdateVirtualKeyMCPConfig(ctx, existing, tx); err != nil { + return fmt.Errorf("failed to update VK MCP config for %s: %w", vc.VirtualKeyID, err) + } + } else { + if err := h.store.ConfigStore.CreateVirtualKeyMCPConfig(ctx, &configstoreTables.TableVirtualKeyMCPConfig{ + VirtualKeyID: vc.VirtualKeyID, + MCPClientID: oldDBConfig.ID, + ToolsToExecute: vc.ToolsToExecute, + }, tx); err != nil { + return fmt.Errorf("failed to create VK MCP config for %s: %w", vc.VirtualKeyID, err) + } + } + } + // Delete removed assignments + for vkID, existing := range currentByVKID { + if _, ok := requestedByVKID[vkID]; !ok { + if err := h.store.ConfigStore.DeleteVirtualKeyMCPConfig(ctx, existing.ID, tx); err != nil { + return fmt.Errorf("failed to remove VK MCP config for %s: %w", vkID, err) + } + } + } + return nil + }); err != nil { + // NOTE: Partial success β€” the MCP client config was already updated in DB and memory above. + // Only the VK assignment changes failed. The VK assignments remain unchanged in DB. + // The MCP client update is idempotent, so retrying the full request is safe. + logger.Error(fmt.Sprintf( + "[PARTIAL SUCCESS] MCP client %s was updated successfully but VK assignment update failed: %v. "+ + "VK assignments remain unchanged. Retry the request to apply VK changes.", + id, err, + )) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("MCP client was updated but VK assignment update failed: %v", err)) + return + } + // Reload all affected VKs in memory so governance enforcement reflects the new MCP assignments. + // requestedByVKID and currentByVKID together cover the full affected set (no duplicates since both are maps). + if h.governanceManager != nil { + for vkID := range requestedByVKID { + if _, err := h.governanceManager.ReloadVirtualKey(ctx, vkID); err != nil { + logger.Error(fmt.Sprintf("failed to reload virtual key %s in memory after MCP VK assignment update: %v", vkID, err)) + } + } + for vkID := range currentByVKID { + if _, alreadyReloaded := requestedByVKID[vkID]; !alreadyReloaded { + if _, err := h.governanceManager.ReloadVirtualKey(ctx, vkID); err != nil { + logger.Error(fmt.Sprintf("failed to reload virtual key %s in memory after MCP VK assignment update: %v", vkID, err)) + } + } + } + } + } + SendJSON(ctx, map[string]any{ "status": "success", "message": "MCP client edited successfully", @@ -647,64 +894,37 @@ func getIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) { return idStr, nil } -func validateToolsToExecute(toolsToExecute []string) error { - if len(toolsToExecute) > 0 { - // Check if wildcard "*" is combined with other tool names - hasWildcard := slices.Contains(toolsToExecute, "*") - if hasWildcard && len(toolsToExecute) > 1 { - return fmt.Errorf("invalid tools_to_execute: wildcard '*' cannot be combined with other tool names") - } - - // Check for duplicate entries - seen := make(map[string]bool) - for _, tool := range toolsToExecute { - if seen[tool] { - return fmt.Errorf("invalid tools_to_execute: duplicate tool name '%s'", tool) - } - seen[tool] = true - } +func validateToolsToExecute(toolsToExecute schemas.WhiteList) error { + if err := toolsToExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_execute: %w", err) } - return nil } -func validateToolsToAutoExecute(toolsToAutoExecute []string, toolsToExecute []string) error { - if len(toolsToAutoExecute) > 0 { - // Check if wildcard "*" is combined with other tool names - hasWildcard := slices.Contains(toolsToAutoExecute, "*") - if hasWildcard && len(toolsToAutoExecute) > 1 { - return fmt.Errorf("wildcard '*' cannot be combined with other tool names") - } +func validateAllowedExtraHeaders(allowedExtraHeaders schemas.WhiteList) error { + if err := allowedExtraHeaders.Validate(); err != nil { + return fmt.Errorf("invalid allowed_extra_headers: %w", err) + } + return nil +} - // Check for duplicate entries - seen := make(map[string]bool) - for _, tool := range toolsToAutoExecute { - if seen[tool] { - return fmt.Errorf("duplicate tool name '%s'", tool) - } - seen[tool] = true - } +func validateToolsToAutoExecute(toolsToAutoExecute schemas.WhiteList, toolsToExecute schemas.WhiteList) error { + if err := toolsToAutoExecute.Validate(); err != nil { + return fmt.Errorf("invalid tools_to_auto_execute: %w", err) + } - // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute - // Create a set of allowed tools from ToolsToExecute - allowedTools := make(map[string]bool) - hasWildcardInExecute := slices.Contains(toolsToExecute, "*") - if hasWildcardInExecute { - // If "*" is in ToolsToExecute, all tools are allowed + if !toolsToAutoExecute.IsEmpty() { + // If ToolsToExecute allows all, no further cross-validation needed + if toolsToExecute.IsUnrestricted() { return nil } - for _, tool := range toolsToExecute { - allowedTools[tool] = true - } - // Validate each tool in ToolsToAutoExecute + // Check that all tools in ToolsToAutoExecute are also in ToolsToExecute for _, tool := range toolsToAutoExecute { if tool == "*" { - // Wildcard is allowed if "*" is in ToolsToExecute - if !hasWildcardInExecute { - return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) - } - } else if !allowedTools[tool] { + return fmt.Errorf("tool '*' in tools_to_auto_execute requires '*' in tools_to_execute") + } + if !toolsToExecute.Contains(tool) { return fmt.Errorf("tool '%s' in tools_to_auto_execute is not in tools_to_execute", tool) } } @@ -750,7 +970,11 @@ func mergeMCPRedactedValues(incoming *configstoreTables.TableMCPClient, oldRaw, // Preserve IsPingAvailable if not explicitly set in incoming request // This prevents the zero-value (false) from overwriting the existing DB value if incoming.IsPingAvailable == nil { - merged.IsPingAvailable = bifrost.Ptr(oldRaw.IsPingAvailable) + merged.IsPingAvailable = oldRaw.IsPingAvailable + } + // Preserve AllowedExtraHeaders if not explicitly set in incoming request + if incoming.AllowedExtraHeaders == nil { + merged.AllowedExtraHeaders = oldRaw.AllowedExtraHeaders } return merged @@ -801,7 +1025,68 @@ func (h *MCPHandler) completeMCPClientOAuth(ctx *fasthttp.RequestCtx) { return } - // Creating MCP client config in config store + // Handle per-user OAuth completion: verify connection with admin's temp token, + // discover tools, create client (without persistent connection), discard token. + if mcpClientConfig.AuthType == schemas.MCPAuthTypePerUserOauth { + // Get admin's temporary access token for verification + accessToken, err := h.oauthHandler.GetAccessToken(ctx, oauthConfigID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get admin access token for verification: %v", err)) + return + } + // Always clean up admin's temp token and pending config, even on failure + defer h.oauthHandler.RevokeToken(ctx, oauthConfigID) + defer h.oauthHandler.RemovePendingMCPClient(oauthConfigID) + + // Verify connection and discover tools using admin's temp token + tools, toolNameMapping, err := h.mcpManager.VerifyPerUserOAuthConnection(ctx, mcpClientConfig, accessToken) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("OAuth configuration test failed: %v", err)) + return + } + + // Persist MCP client config in config store + if h.store.ConfigStore != nil { + if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create MCP config: %v", err)) + return + } + } + + // Add MCP client to manager (skips connection for per_user_oauth) + if err := h.mcpManager.AddMCPClient(ctx, mcpClientConfig); err != nil { + // Clean up DB entry on failure + if h.store.ConfigStore != nil { + if delErr := h.store.ConfigStore.DeleteMCPClientConfig(ctx, mcpClientConfig.ID); delErr != nil { + logger.Error(fmt.Sprintf("Failed to delete MCP client config from database: %v. please restart bifrost to keep core and database in sync", delErr)) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete MCP client config from database: %v. please restart bifrost to keep core and database in sync", delErr)) + return + } + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to register MCP client: %v", err)) + return + } + + // Set discovered tools on the client + h.mcpManager.SetClientTools(mcpClientConfig.ID, tools, toolNameMapping) + + // Persist discovered tools to DB so they survive restart + if h.store.ConfigStore != nil { + if err := h.store.ConfigStore.UpdateMCPClientDiscoveredTools(ctx, mcpClientConfig.ID, tools, toolNameMapping); err != nil { + logger.Warn(fmt.Sprintf("[OAuth Complete] Failed to persist discovered tools for %s: %v", mcpClientConfig.ID, err)) + } + } + + logger.Debug(fmt.Sprintf("[OAuth Complete] Per-user OAuth MCP client verified and created: %s (%d tools)", mcpClientConfig.ID, len(tools))) + SendJSON(ctx, map[string]any{ + "status": "success", + "message": fmt.Sprintf("OAuth configuration verified successfully. %d tools discovered. Each user will authenticate individually when using this MCP server.", len(tools)), + "tools_count": len(tools), + }) + return + } + + // Standard server-level OAuth completion if h.store.ConfigStore != nil { if err := h.store.ConfigStore.CreateMCPClientConfig(ctx, mcpClientConfig); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create MCP config: %v", err)) diff --git a/transports/bifrost-http/handlers/mcpinference.go b/transports/bifrost-http/handlers/mcpinference.go index 80856dd8a3..4e80e18d5d 100644 --- a/transports/bifrost-http/handlers/mcpinference.go +++ b/transports/bifrost-http/handlers/mcpinference.go @@ -14,14 +14,14 @@ import ( type MCPInferenceHandler struct { client *bifrost.Bifrost - store *lib.Config + config *lib.Config } // NewMCPInferenceHandler creates a new MCP inference handler instance -func NewMCPInferenceHandler(client *bifrost.Bifrost, store *lib.Config) *MCPInferenceHandler { +func NewMCPInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *MCPInferenceHandler { return &MCPInferenceHandler{ client: client, - store: store, + config: config, } } @@ -60,7 +60,7 @@ func (h *MCPInferenceHandler) executeChatMCPTool(ctx *fasthttp.RequestCtx) { } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") @@ -93,7 +93,7 @@ func (h *MCPInferenceHandler) executeResponsesMCPTool(ctx *fasthttp.RequestCtx) } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.store.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() // Ensure cleanup on function exit if bifrostCtx == nil { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index 31e9c448d8..f3214e801c 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -5,9 +5,9 @@ package handlers import ( "context" "fmt" - "slices" "strings" "sync" + "time" "github.com/bytedance/sonic" "github.com/fasthttp/router" @@ -64,6 +64,9 @@ func NewMCPServerHandler(ctx context.Context, config *lib.Config, toolManager MC // Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer) + // Register per-request tool filter so x-bf-mcp-include-clients and x-bf-mcp-include-tools are respected on tools/list + server.WithToolFilter(handler.makeIncludeClientsFilter())(handler.globalMCPServer) + if err := handler.SyncAllMCPServers(ctx); err != nil { return nil, fmt.Errorf("failed to sync all MCP servers: %w", err) } @@ -79,17 +82,44 @@ func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schem } // handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages +// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth +// session exists, injects the session token and identity (VK / User ID) directly +// into the BifrostContext. This avoids header-based identity propagation which +// would be vulnerable to spoofing by upstream callers. +// +// Governance context keys are set here intentionally (bypassing governance plugin) +// because in the MCP gateway path, identity is pre-authenticated via the OAuth session. +func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) { + bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true) + if session != nil { + if session.AccessToken != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken) + } + if session.VirtualKeyID != nil && *session.VirtualKeyID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, *session.VirtualKeyID) + if session.VirtualKey != nil && session.VirtualKey.Name != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyName, session.VirtualKey.Name) + } + } + if session.UserID != nil && *session.UserID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, *session.UserID) + } + } +} + func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { - mcpServer, err := h.getMCPServerForRequest(ctx) + mcpServer, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return } // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() + injectMCPSessionIdentity(bifrostCtx, session) + // Use mcp-go server to handle the request // HandleMessage processes JSON-RPC messages and returns appropriate responses response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody()) @@ -114,7 +144,7 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { // handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { - _, err := h.getMCPServerForRequest(ctx) + _, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return @@ -126,7 +156,9 @@ func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("Connection", "keep-alive") // Convert context - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + + injectMCPSessionIdentity(bifrostCtx, session) // Use SSEStreamReader to bypass fasthttp's internal pipe batching reader := lib.NewSSEStreamReader() @@ -235,7 +267,7 @@ func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools [ handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Inject tool filter into execution context if present if toolFilter != nil { - ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), toolFilter) + ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, toolFilter) } // Convert to Bifrost tool call format toolCallType := "function" @@ -256,6 +288,12 @@ func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools [ // Execute the tool via tool executor toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall) if err != nil { + if err.ExtraFields.MCPAuthRequired != nil { + return mcp.NewToolResultError(fmt.Sprintf( + "Authentication required for %s. Open this URL to connect your account: %s", + err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL, + )), nil + } return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil } @@ -323,35 +361,50 @@ func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schema ctx := context.Background() var toolFilter []string - if len(vk.MCPConfigs) > 0 { - executeOnlyTools := make([]string, 0) - for _, vkMcpConfig := range vk.MCPConfigs { - if len(vkMcpConfig.ToolsToExecute) == 0 { - // No tools specified in virtual key config - skip this client entirely - continue - } + executeOnlyTools := make([]string, 0) - // Handle wildcard in virtual key config - allow all tools from this client - if slices.Contains(vkMcpConfig.ToolsToExecute, "*") { - // Virtual key uses wildcard - use client-specific wildcard - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) - continue - } + // Build a lookup of AllowOnAllVirtualKeys clients: clientID -> clientName. + // Explicit VK MCPConfigs always take precedence over AllowOnAllVirtualKeys. + allowAllVKsClients := h.config.GetAllowOnAllVirtualKeysClients() + if allowAllVKsClients == nil { + allowAllVKsClients = make(map[string]string) + } - for _, tool := range vkMcpConfig.ToolsToExecute { - if tool != "" { - // Add the tool - client config filtering will be handled by mcp.go - // Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*") - executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) - } + // Process explicit VK MCPConfigs first. + handledClients := make(map[string]bool) + for _, vkMcpConfig := range vk.MCPConfigs { + clientID := vkMcpConfig.MCPClient.ClientID + if _, isAllowAll := allowAllVKsClients[clientID]; isAllowAll { + // Explicit config exists β€” it takes precedence; mark handled regardless of tool list. + handledClients[clientID] = true + } + if vkMcpConfig.ToolsToExecute.IsEmpty() { + continue + } + if vkMcpConfig.ToolsToExecute.IsUnrestricted() { + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", vkMcpConfig.MCPClient.Name)) + continue + } + for _, tool := range vkMcpConfig.ToolsToExecute { + if tool != "" { + // Add the tool - client config filtering will be handled by mcp.go + // Note: Use '-' separator for individual tools (wildcard uses '-*' after client name, e.g., "client-*") + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-%s", vkMcpConfig.MCPClient.Name, tool)) } } + } - // Set even when empty to exclude tools when no tools are present in the virtual key config - ctx = context.WithValue(ctx, schemas.BifrostContextKey("mcp-include-tools"), executeOnlyTools) - toolFilter = executeOnlyTools + // For AllowOnAllVirtualKeys clients with no explicit VK config, allow all their tools. + for clientID, clientName := range allowAllVKsClients { + if !handledClients[clientID] { + executeOnlyTools = append(executeOnlyTools, fmt.Sprintf("%s-*", clientName)) + } } + // Always set the include-tools filter (empty = deny-all when no MCPConfigs and no AllowOnAllVirtualKeys clients) + ctx = context.WithValue(ctx, schemas.MCPContextKeyIncludeTools, executeOnlyTools) + toolFilter = executeOnlyTools + return h.toolManager.GetAvailableMCPTools(ctx), toolFilter } @@ -360,7 +413,7 @@ func (h *MCPServerHandler) fetchToolsForVK(vk *tables.TableVirtualKey) ([]schema // When neither header is present the filter is a no-op, preserving existing behaviour. func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc { return func(ctx context.Context, tools []mcp.Tool) []mcp.Tool { - if ctx.Value(schemas.BifrostContextKey("mcp-include-clients")) == nil && ctx.Value(schemas.BifrostContextKey("mcp-include-tools")) == nil { + if ctx.Value(schemas.MCPContextKeyIncludeClients) == nil && ctx.Value(schemas.MCPContextKeyIncludeTools) == nil { return tools } allowed := h.toolManager.GetAvailableMCPTools(ctx) @@ -382,7 +435,7 @@ func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc { // Utility methods -func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, error) { +func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) { h.mu.RLock() defer h.mu.RUnlock() @@ -392,23 +445,92 @@ func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*se vk := getVKFromRequest(ctx) + // Check for Bifrost per-user OAuth Bearer token (not a VK) + userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx) + if sessionErr != nil { + return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr) + } + + // If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery + if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" { + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host) + ctx.Response.Header.Set("WWW-Authenticate", + fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL)) + return nil, nil, fmt.Errorf("oauth authentication required for mcp access") + } + + if userOauthSession != nil { + if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") { + return h.globalMCPServer, userOauthSession, nil + } + + if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil { + return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key") + } + + vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value] + if !ok { + return nil, nil, fmt.Errorf("virtual key not found") + } + + return vkServer, userOauthSession, nil + } + // Return global MCP server if not enforcing virtual key header and no virtual key is provided if !enforceVK && vk == "" { - return h.globalMCPServer, nil + return h.globalMCPServer, nil, nil } - // Check if virtual key is provided if vk == "" { - return nil, fmt.Errorf("virtual key header is required to access MCP server.") + return nil, nil, fmt.Errorf("virtual key header required to access mcp server") } - // Check if vk exists in the map vkServer, ok := h.vkMCPServers[vk] if !ok { - return nil, fmt.Errorf("virtual key not found.") + return nil, nil, fmt.Errorf("virtual key not found") + } + + return vkServer, nil, nil +} + +// getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth +// token from the Authorization header. Returns the session if valid, nil otherwise. +func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) { + authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) + if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return nil, nil + } + token := strings.TrimSpace(authHeader[7:]) + if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { + return nil, nil // It's a virtual key, not a per-user OAuth token + } + + if h.config.ConfigStore == nil { + return nil, nil + } + + session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token) + if err != nil { + logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err) + return nil, err + } + if session == nil { + logger.Debug("[mcp/auth] Session not found for token") + return nil, nil + } + + // Check expiry + if session.ExpiresAt.Before(time.Now()) { + logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt) + return nil, nil } - return vkServer, nil + return session, nil } func getVKFromRequest(ctx *fasthttp.RequestCtx) string { diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 9ada2d0ad7..05b8c30813 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -12,16 +12,19 @@ import ( "sync/atomic" "time" + "github.com/google/uuid" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/tracing" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) var loggingSkipPaths = []string{"/health", "/_next", "/api/dev"} +var realtimeTransportPaths = buildRealtimeTransportPathSet() // SecurityHeadersMiddleware sets security-related HTTP headers on every response. // This should wrap the outermost handler so all responses (API, UI, errors) include these headers. @@ -82,7 +85,7 @@ func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { isLocalhostOrigin(origin) || slices.Contains(config.ClientConfig.AllowedOrigins, origin) - allowedHeaders := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key"} + allowedHeaders := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key", "X-OpenAI-Agents-SDK"} if slices.Contains(config.ClientConfig.AllowedHeaders, "*") { if credentialed { // Per the Fetch spec, Access-Control-Allow-Headers: * is NOT treated as a @@ -315,20 +318,33 @@ func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddl fasthttpToHTTPRequest(ctx, req) // Run plugin interceptors for _, plugin := range plugins { - resp, err := plugin.HTTPTransportPreHook(bifrostCtx, req) + pluginName := plugin.GetName() + pluginCtx := bifrostCtx.WithPluginScope(&pluginName) + resp, err := plugin.HTTPTransportPreHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() if err != nil { - // Short-circuit with error + // Short-circuit with error β€” drain plugin logs before returning + if logs := bifrostCtx.DrainPluginLogs(); len(logs) > 0 { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, logs) + } ctx.SetStatusCode(fasthttp.StatusInternalServerError) ctx.SetBodyString(err.Error()) return } if resp != nil { - // Short-circuit with response + // Short-circuit with response β€” drain plugin logs before returning + if logs := bifrostCtx.DrainPluginLogs(); len(logs) > 0 { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, logs) + } applyHTTPResponseToCtx(ctx, resp) return } // If we got here, the plugin may have modified req in-place } + // Drain pre-hook plugin logs and store on fasthttp context for trace attachment + if preHookLogs := bifrostCtx.DrainPluginLogs(); len(preHookLogs) > 0 { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, preHookLogs) + } // Apply modifications back to fasthttp context applyHTTPRequestToCtx(ctx, req) // Adding user values @@ -337,31 +353,70 @@ func TransportInterceptorMiddleware(config *lib.Config) schemas.BifrostHTTPMiddl } next(ctx) - // Skip HTTPTransportPostHook for streaming responses - // Streaming handlers set DeferTraceCompletion and use StreamChunkInterceptor for per-chunk hooks + // For streaming responses, store a callback to run post-hooks after the stream ends. + // The streaming handler calls this before traceCompleter. if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, func() { + runTransportPostHooks(ctx, plugins, bifrostCtx) + }) return } - // Acquire pooled response for post-hooks (non-streaming only) - httpResp := schemas.AcquireHTTPResponse() - defer schemas.ReleaseHTTPResponse(httpResp) - fasthttpResponseToHTTPResponse(ctx, httpResp) - // Run http post-hooks in reverse order - for i := len(plugins) - 1; i >= 0; i-- { - plugin := plugins[i] - err := plugin.HTTPTransportPostHook(bifrostCtx, req, httpResp) - if err != nil { - logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", plugin.GetName(), err.Error()) - // Short-circuit with response - applyHTTPResponseToCtx(ctx, httpResp) - return + runTransportPostHooks(ctx, plugins, bifrostCtx) + } + } +} + +// runTransportPostHooks runs HTTPTransportPostHook for all plugins in reverse order, +// drains plugin logs, and applies the response back to the fasthttp context. +// Used for both non-streaming (inline) and streaming (deferred callback) paths. +// +// Transport-level plugin logs are stored in fasthttp UserValues (keyed by +// BifrostContextKeyTransportPluginLogs) rather than directly on BifrostContext, +// because transport hooks operate at the fasthttp layer before/after the core +// BifrostContext lifecycle. These logs are merged into the trace by the +// TracingMiddleware at trace completion, alongside core-level plugin logs +// which travel through BifrostContext β†’ Trace β†’ AttachPluginLogs. +func runTransportPostHooks(ctx *fasthttp.RequestCtx, plugins []schemas.HTTPTransportPlugin, bifrostCtx *schemas.BifrostContext) { + httpResp := schemas.AcquireHTTPResponse() + defer schemas.ReleaseHTTPResponse(httpResp) + fasthttpResponseToHTTPResponse(ctx, httpResp) + + // Build request from current fasthttp state (original pooled req may have been released) + req := schemas.AcquireHTTPRequest() + defer schemas.ReleaseHTTPRequest(req) + fasthttpToHTTPRequest(ctx, req) + + // Run http post-hooks in reverse order + for i := len(plugins) - 1; i >= 0; i-- { + plugin := plugins[i] + pluginName := plugin.GetName() + pluginCtx := bifrostCtx.WithPluginScope(&pluginName) + err := plugin.HTTPTransportPostHook(pluginCtx, req, httpResp) + pluginCtx.ReleasePluginScope() + if err != nil { + logger.Warn("error in HTTPTransportPostHook for plugin %s: %s", pluginName, err.Error()) + // Drain plugin logs before returning on error + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + if existing, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, append(existing, postHookLogs...)) + } else { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, postHookLogs) } } - // Apply modifications back to fasthttp context applyHTTPResponseToCtx(ctx, httpResp) + return + } + } + // Drain post-hook plugin logs and merge with pre-hook logs + if postHookLogs := bifrostCtx.DrainPluginLogs(); len(postHookLogs) > 0 { + if existing, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, append(existing, postHookLogs...)) + } else { + ctx.SetUserValue(schemas.BifrostContextKeyTransportPluginLogs, postHookLogs) } } + applyHTTPResponseToCtx(ctx, httpResp) } // getBifrostContextFromFastHTTP gets or creates a BifrostContext from fasthttp context. @@ -501,7 +556,41 @@ func validateSession(_ *fasthttp.RequestCtx, store configstore.ConfigStore, toke // isInferenceWSEndpoint returns true for WebSocket endpoints that should use // standard inference auth (Bearer/Basic/VK) rather than dashboard session tokens. func isInferenceWSEndpoint(path string) bool { - return path == "/v1/responses" || path == "/v1/realtime" + for strings.HasPrefix(path, "/openai/") { + path = strings.TrimPrefix(path, "/openai") + } + + switch path { + case "/v1/responses", + "/responses", + "/v1/realtime", + "/realtime": + return true + default: + return false + } +} + +func buildRealtimeTransportPathSet() map[string]struct{} { + paths := map[string]struct{}{} + for _, path := range integrations.OpenAIRealtimePaths("") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("/openai") { + paths[path] = struct{}{} + } + return paths +} + +func isRealtimeTransportEndpoint(path string) bool { + _, ok := realtimeTransportPaths[path] + return ok } // AuthMiddleware is a middleware that handles authentication for the API. @@ -614,6 +703,10 @@ func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, str next(ctx) return } + if isRealtimeTransportEndpoint(string(ctx.Path())) { + next(ctx) + return + } // If inference is disabled, we skip authorization // Get the authorization header authorization := string(ctx.Request.Header.Peek("Authorization")) @@ -773,24 +866,23 @@ func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, str // // This middleware should be placed early in the middleware chain to capture the full request lifecycle. type TracingMiddleware struct { - tracer atomic.Pointer[tracing.Tracer] - obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] + tracer atomic.Pointer[tracing.Tracer] } // NewTracingMiddleware creates a new tracing middleware -func NewTracingMiddleware(tracer *tracing.Tracer, obsPlugins []schemas.ObservabilityPlugin) *TracingMiddleware { +func NewTracingMiddleware(tracer *tracing.Tracer) *TracingMiddleware { tm := &TracingMiddleware{ - tracer: atomic.Pointer[tracing.Tracer]{}, - obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, + tracer: atomic.Pointer[tracing.Tracer]{}, } tm.tracer.Store(tracer) - tm.obsPlugins.Store(&obsPlugins) return tm } // SetObservabilityPlugins sets the observability plugins for the tracing middleware func (m *TracingMiddleware) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { - m.obsPlugins.Store(&obsPlugins) + if tracer := m.tracer.Load(); tracer != nil { + tracer.SetObservabilityPlugins(obsPlugins) + } } // SetTracer sets the tracer for the tracing middleware @@ -802,19 +894,26 @@ func (m *TracingMiddleware) SetTracer(tracer *tracing.Tracer) { func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - // Skip if store is nil - if m.tracer.Load() == nil { + // Pin the tracer for the lifetime of this request so that a concurrent + // SetTracer() swap cannot split a trace across two instances. + tracer := m.tracer.Load() + if tracer == nil { next(ctx) return } + requestID := string(ctx.Request.Header.Peek("x-request-id")) + if requestID == "" { + requestID = uuid.New().String() + // Injecting this back to be picked up by the next middleware + ctx.Request.Header.Set("x-request-id", requestID) + } // Extract trace ID from W3C traceparent header (if present) // This is the 32-char trace ID that links all spans in a distributed trace inheritedTraceID := tracing.ExtractParentID(&ctx.Request.Header) // Create trace in store - only ID returned (trace data stays in store) - traceID := m.tracer.Load().CreateTrace(inheritedTraceID) + traceID := tracer.CreateTrace(inheritedTraceID, requestID) // Only trace ID goes into context (lightweight, no bloat) ctx.SetUserValue(schemas.BifrostContextKeyTraceID, traceID) - // Extract parent span ID from W3C traceparent header (if present) // This is the 16-char span ID from the upstream service that should be // set as the ParentID of our root span for proper trace linking in Datadog/etc. @@ -825,14 +924,22 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { // Store a trace completion callback for streaming handlers to use ctx.SetUserValue(schemas.BifrostContextKeyTraceCompleter, func() { - m.completeAndFlushTrace(traceID) + // Run deferred HTTPTransportPostHook for streaming responses + if postHookCompleter, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPostHookCompleter).(func()); ok { + postHookCompleter() + } + // Attach transport plugin logs before completing the trace (streaming path) + if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { + tracer.AttachPluginLogs(traceID, transportLogs) + } + tracer.CompleteAndFlushTrace(traceID) }) // Create root span for the HTTP request - spanCtx, rootSpan := m.tracer.Load().StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) + spanCtx, rootSpan := tracer.StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) if rootSpan != nil { - m.tracer.Load().SetAttribute(rootSpan, "http.method", string(ctx.Method())) - m.tracer.Load().SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) - m.tracer.Load().SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) + tracer.SetAttribute(rootSpan, "http.method", string(ctx.Method())) + tracer.SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) + tracer.SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) // Set root span ID in context for child span creation if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetUserValue(schemas.BifrostContextKeySpanID, spanID) @@ -841,11 +948,11 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { defer func() { // Record response status on the root span if rootSpan != nil { - m.tracer.Load().SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) + tracer.SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) if ctx.Response.StatusCode() >= 400 { - m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) + tracer.EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) } else { - m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusOk, "") + tracer.EndSpan(rootSpan, schemas.SpanStatusOk, "") } } // Check if trace completion is deferred (for streaming requests) @@ -853,8 +960,12 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { if deferred, ok := ctx.UserValue(schemas.BifrostContextKeyDeferTraceCompletion).(bool); ok && deferred { return } + // Attach transport plugin logs to trace before completion + if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { + tracer.AttachPluginLogs(traceID, transportLogs) + } // After response written - async flush - m.completeAndFlushTrace(traceID) + tracer.CompleteAndFlushTrace(traceID) }() next(ctx) @@ -862,32 +973,6 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { } } -// completeAndFlushTrace completes the trace and forwards it to observability plugins. -// This is called either by the middleware defer (for non-streaming) or by streaming handlers. -func (m *TracingMiddleware) completeAndFlushTrace(traceID string) { - go func() { - // Clean up the stream accumulator for this trace - - // Get completed trace from store - completedTrace := m.tracer.Load().EndTrace(traceID) - if completedTrace == nil { - return - } - // Forward to all observability plugins - for _, plugin := range *m.obsPlugins.Load() { - if plugin == nil { - continue - } - // Call inject with a background context (request context is done) - if err := plugin.Inject(context.Background(), completedTrace); err != nil { - logger.Warn("observability plugin %s failed to inject trace: %v", plugin.GetName(), err) - } - } - // Return trace to pool for reuse - m.tracer.Load().ReleaseTrace(completedTrace) - }() -} - // GetTracer returns the tracer instance for use by streaming handlers func (m *TracingMiddleware) GetTracer() *tracing.Tracer { return m.tracer.Load() diff --git a/transports/bifrost-http/handlers/middlewares_test.go b/transports/bifrost-http/handlers/middlewares_test.go index 8e61a975be..9f951b876b 100644 --- a/transports/bifrost-http/handlers/middlewares_test.go +++ b/transports/bifrost-http/handlers/middlewares_test.go @@ -71,7 +71,7 @@ func TestCorsMiddleware_LocalhostOrigins(t *testing.T) { if string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")) != "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" { t.Errorf("Access-Control-Allow-Methods header not set correctly") } - if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key" { + if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK" { t.Errorf("Access-Control-Allow-Headers header not set correctly") } if string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")) != "true" { @@ -410,6 +410,69 @@ func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) { chained(ctx) } +func TestIsInferenceWSEndpoint(t *testing.T) { + paths := []string{ + "/v1/responses", + "/v1/realtime", + "/responses", + "/realtime", + "/openai/v1/responses", + "/openai/responses", + "/openai/openai/responses", + "/openai/v1/realtime", + "/openai/realtime", + "/openai/openai/realtime", + } + + for _, path := range paths { + if !isInferenceWSEndpoint(path) { + t.Fatalf("expected inference websocket path %s to be recognized", path) + } + } + + if isInferenceWSEndpoint("/api/ws") { + t.Fatal("dashboard websocket path should not be treated as inference websocket") + } + if isInferenceWSEndpoint("/openai/chat/completions") { + t.Fatal("non-websocket OpenAI path should not be treated as inference websocket") + } +} + +func TestIsRealtimeTransportEndpoint(t *testing.T) { + paths := []string{ + "/v1/realtime", + "/realtime", + "/openai/realtime", + "/openai/v1/realtime", + "/openai/openai/realtime", + "/v1/realtime/calls", + "/realtime/calls", + "/openai/realtime/calls", + "/openai/v1/realtime/calls", + "/openai/openai/realtime/calls", + } + + for _, path := range paths { + if !isRealtimeTransportEndpoint(path) { + t.Fatalf("expected realtime transport path %s to be recognized", path) + } + } + + nonTransportPaths := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + "/openai/v1/realtime/client_secrets", + "/openai/v1/realtime/sessions", + "/v1/chat/completions", + } + + for _, path := range nonTransportPaths { + if isRealtimeTransportEndpoint(path) { + t.Fatalf("did not expect non-transport path %s to be recognized", path) + } + } +} + // Testlib.ChainMiddlewares_ShortCircuit tests that when a middleware writes a response // and does not call next, subsequent middlewares and handler do not execute. func TestChainMiddlewares_ShortCircuit(t *testing.T) { @@ -663,6 +726,83 @@ func TestAuthMiddleware_WhitelistedRoutes(t *testing.T) { } } +func TestAuthMiddleware_InferenceMiddleware_RealtimeTransportBypassesAuth(t *testing.T) { + SetLogger(&mockLogger{}) + + am := &AuthMiddleware{} + am.UpdateAuthConfig(&configstore.AuthConfig{ + AdminUserName: schemas.NewEnvVar("admin"), + AdminPassword: schemas.NewEnvVar("hashedpassword"), + IsEnabled: true, + }) + + routes := []string{ + "/v1/realtime", + "/openai/v1/realtime", + "/v1/realtime/calls?model=gpt-realtime", + "/openai/v1/realtime/calls?model=gpt-realtime", + } + + for _, route := range routes { + t.Run(route, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI(route) + + nextCalled := false + next := func(ctx *fasthttp.RequestCtx) { + nextCalled = true + } + + handler := am.InferenceMiddleware()(next) + handler(ctx) + + if !nextCalled { + t.Fatalf("expected realtime transport route %s to bypass auth", route) + } + }) + } +} + +func TestAuthMiddleware_InferenceMiddleware_RealtimeMintingStillRequiresAuth(t *testing.T) { + SetLogger(&mockLogger{}) + + am := &AuthMiddleware{} + am.UpdateAuthConfig(&configstore.AuthConfig{ + AdminUserName: schemas.NewEnvVar("admin"), + AdminPassword: schemas.NewEnvVar("hashedpassword"), + IsEnabled: true, + }) + + routes := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + "/openai/v1/realtime/client_secrets", + "/openai/v1/realtime/sessions", + } + + for _, route := range routes { + t.Run(route, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI(route) + + nextCalled := false + next := func(ctx *fasthttp.RequestCtx) { + nextCalled = true + } + + handler := am.InferenceMiddleware()(next) + handler(ctx) + + if nextCalled { + t.Fatalf("expected realtime minting route %s to still require auth", route) + } + if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized { + t.Fatalf("expected %d for route %s, got %d", fasthttp.StatusUnauthorized, route, ctx.Response.StatusCode()) + } + }) + } +} + // TestAuthMiddleware_UpdateAuthConfig_NilToEnabled tests updating auth config from nil to enabled func TestAuthMiddleware_UpdateAuthConfig_NilToEnabled(t *testing.T) { SetLogger(&mockLogger{}) @@ -864,7 +1004,7 @@ func TestCorsMiddleware_DefaultHeaders(t *testing.T) { handler(ctx) // Check default headers are set - expectedHeaders := "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key" + expectedHeaders := "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK" actualHeaders := string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) if actualHeaders != expectedHeaders { t.Errorf("Expected Access-Control-Allow-Headers to be %s, got %s", expectedHeaders, actualHeaders) diff --git a/transports/bifrost-http/handlers/oauth2.go b/transports/bifrost-http/handlers/oauth2.go index a7ee470e15..5240b32b91 100644 --- a/transports/bifrost-http/handlers/oauth2.go +++ b/transports/bifrost-http/handlers/oauth2.go @@ -5,8 +5,11 @@ package handlers import ( "context" "encoding/json" + "errors" "fmt" "html" + "net/url" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -59,7 +62,39 @@ func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) { return } - // Complete OAuth flow + // Try per-user OAuth runtime flow first (state from oauth_user_sessions table). + // This handles the case where an end-user authenticates during inference. + sessionToken, perUserErr := h.oauthProvider.CompleteUserOAuthFlow(context.Background(), state, code) + if perUserErr != nil && !errors.Is(perUserErr, schemas.ErrOAuth2NotPerUserSession) { + // Real per-user error (not "state not found") β€” don't fall through to admin flow + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Per-user OAuth flow failed: %v", perUserErr)) + return + } + if perUserErr == nil && sessionToken != "" { + // Consent flow: session token is a flow proxy ("flow::"). + // Redirect back to the MCPs consent page so the user can continue. + if strings.HasPrefix(sessionToken, "flow:") { + rest := strings.TrimPrefix(sessionToken, "flow:") + flowID := strings.SplitN(rest, ":", 2)[0] + mcpsURL := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)) + ctx.Redirect(mcpsURL, fasthttp.StatusFound) + return + } + + // Per-user runtime OAuth flow completed β€” show success page. + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("text/html") + ctx.SetBodyString(oauthSuccessPage(` + if (window.opener) { + window.opener.postMessage({ type: 'oauth_success' }, window.location.origin); + window.close(); + } + `, "Authorization Successful", "You can close this tab.")) + return + } + + // Fall through to standard OAuth flow (handles both admin test logins for + // per_user_oauth setup and regular server-level OAuth). if err := h.oauthProvider.CompleteOAuthFlow(context.Background(), state, code); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("OAuth flow completion failed: %v", err)) return @@ -68,31 +103,12 @@ func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) { // Redirect to success page (or close popup) ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("text/html") - ctx.SetBodyString(` - - - - OAuth Success - - - -
-
-

βœ“ Authorization Successful

-

This window will close automatically...

-
-
- - - `) + ctx.SetBodyString(oauthSuccessPage(` + if (window.opener) { + window.opener.postMessage({ type: 'oauth_success' }, window.location.origin); + window.close(); + } + `, "Authorization Successful", "OAuth authorization successful! You can close this window.")) } // handleCallbackError handles OAuth callback errors @@ -117,30 +133,7 @@ func (h *OAuthHandler) handleCallbackError(ctx *fasthttp.RequestCtx, state, erro jsEscaped, _ := json.Marshal(errorMsg) // HTML-escape for safe embedding in HTML body (prevents HTML injection) htmlEscaped := html.EscapeString(errorMsg) - ctx.SetBodyString(fmt.Sprintf(` - - - - OAuth Failed - - - -
-
-

βœ— Authorization Failed

-

%s

-

You can close this window.

-
-
- - - `, jsEscaped, htmlEscaped)) + ctx.SetBodyString(oauthErrorPage(string(jsEscaped), htmlEscaped)) } // getOAuthConfigStatus returns the current status of an OAuth config @@ -245,7 +238,83 @@ func (h *OAuthHandler) GetPendingMCPClientByState(state string) (*schemas.MCPCli return h.oauthProvider.GetPendingMCPClientByState(state) } -// RemovePendingMCPClient removes a pending MCP client after OAuth completion +// RemovePendingMCPClient removes a pending MCP client after OAuth completion. func (h *OAuthHandler) RemovePendingMCPClient(oauthConfigID string) error { return h.oauthProvider.RemovePendingMCPClient(oauthConfigID) } + +// GetAccessToken retrieves the access token for a given oauth_config_id. +// Used during per-user OAuth setup to get the admin's temporary token for verification. +func (h *OAuthHandler) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) { + return h.oauthProvider.GetAccessToken(ctx, oauthConfigID) +} + +// oauthSuccessPage renders a Bifrost-themed success HTML page. +// extraScript is injected verbatim into a + + +
+
+

%s

+

%s

+
+ +`, html.EscapeString(title), bifrostPageCSS, extraScript, html.EscapeString(title), html.EscapeString(message)) +} + +// oauthErrorPage renders a Bifrost-themed error HTML page. +// jsEscapedError must already be JSON-encoded (with quotes) for safe JS embedding. +// htmlError must already be HTML-escaped for safe body embedding. +func oauthErrorPage(jsEscapedError, htmlError string) string { + return fmt.Sprintf(` + + + + +Authorization Failed + + + + +
+
+

Authorization Failed

+

%s

+

You can close this window.

+
+ +`, bifrostPageCSS, jsEscapedError, htmlError) +} + +// jsEscapeString returns a JSON-encoded string (with quotes) safe for embedding in JavaScript. +func jsEscapeString(s string) string { + b, _ := json.Marshal(s) + return string(b) +} + +// RevokeToken revokes the OAuth token for a given oauth_config_id. +// Used during per-user OAuth setup to discard the admin's temporary token after verification. +func (h *OAuthHandler) RevokeToken(ctx context.Context, oauthConfigID string) error { + return h.oauthProvider.RevokeToken(ctx, oauthConfigID) +} diff --git a/transports/bifrost-http/handlers/oauth2_consent.go b/transports/bifrost-http/handlers/oauth2_consent.go new file mode 100644 index 0000000000..f1402d9d51 --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_consent.go @@ -0,0 +1,641 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file implements the per-user OAuth consent flow β€” the intermediate screens +// shown between the MCP client's authorize request and the final authorization code +// issuance. The flow is: +// +// 1. GET /oauth/consent?flow_id=xxx β†’ VK input page (HTML) +// 2. POST /api/oauth/per-user/consent/vk β†’ validate VK, update PendingFlow, redirect +// 3. GET /oauth/consent/mcps?flow_id=xxx β†’ MCPs page (HTML, server-rendered) +// 4. POST /api/oauth/per-user/consent/submit β†’ create session + code, redirect to client +package handlers + +import ( + "errors" + "fmt" + "html" + "net/url" + "sort" + "strings" + "time" + + "github.com/fasthttp/router" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ConsentHandler manages the per-user OAuth consent flow screens. +type ConsentHandler struct { + store *lib.Config +} + +// NewConsentHandler creates a new consent handler instance. +func NewConsentHandler(store *lib.Config) *ConsentHandler { + return &ConsentHandler{store: store} +} + +// RegisterRoutes registers the consent flow routes. +// All routes are public β€” no auth middleware β€” since they are part of the OAuth +// flow for unauthenticated users acquiring credentials. +func (h *ConsentHandler) RegisterRoutes(r *router.Router) { + // HTML pages (GET, served by Go) + r.GET("/oauth/consent", h.handleIdentityPage) + r.GET("/oauth/consent/mcps", h.handleMCPsPage) + + // API actions (POST) + // NOTE: All state-mutating endpoints use POST. CSRF protection relies on the + // SameSite=Lax browser-binding cookie (__bifrost_flow_secret) combined with + // the flow_id β€” SameSite=Lax blocks cross-site POST, and the cookie is + // HttpOnly+Secure. This is sufficient for the threat model here. + r.POST("/api/oauth/per-user/consent/vk", h.handleSubmitVK) + r.POST("/api/oauth/per-user/consent/user-id", h.handleSubmitUserID) + r.POST("/api/oauth/per-user/consent/skip", h.handleSkip) + r.POST("/api/oauth/per-user/consent/submit", h.handleSubmit) +} + +// ---------- HTML pages ---------- + +// handleIdentityPage renders the identity selection page with three options: +// User ID, Virtual Key, or skip (lazy auth when tools are called). +// GET /oauth/consent?flow_id=xxx[&error=xxx] +func (h *ConsentHandler) handleIdentityPage(ctx *fasthttp.RequestCtx) { + flowID := string(ctx.QueryArgs().Peek("flow_id")) + errorMsg := string(ctx.QueryArgs().Peek("error")) + + if flowID == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Missing flow_id") + return + } + + if h.store.ConfigStore == nil { + ctx.SetStatusCode(fasthttp.StatusServiceUnavailable) + ctx.SetBodyString("Config store unavailable") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString("Failed to load consent flow.") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.") + return + } + + h.store.Mu.RLock() + enforceVK := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + + safeFlowID := html.EscapeString(flowID) + safeError := html.EscapeString(errorMsg) + + errorBanner := "" + if safeError != "" { + errorBanner = fmt.Sprintf(`
%s
`, safeError) + } + + skipOption := "" + if !enforceVK { + skipOption = fmt.Sprintf(` +
+ Skip for now + Connect to services when a tool is called +
+ + +
+
`, safeFlowID) + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetBodyString(fmt.Sprintf(` + + + + +Connect to Bifrost + + + +
+

Connect to Bifrost

+

Choose how to identify yourself for this session.

+

This setup page expires in 15 minutes.

+ %s +
+ User ID + Use a stable identifier β€” access all available services +
+ + + + +
+
+
+ Virtual Key + Use a VK β€” access services within your key's limits +
+ + + + +
+
+ %s +
+ +`, bifrostPageCSS, errorBanner, safeFlowID, safeFlowID, skipOption)) +} + +// handleMCPsPage renders the MCP authentication list page. +// GET /oauth/consent/mcps?flow_id=xxx +func (h *ConsentHandler) handleMCPsPage(ctx *fasthttp.RequestCtx) { + flowID := string(ctx.QueryArgs().Peek("flow_id")) + + if flowID == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Missing flow_id") + return + } + + if h.store.ConfigStore == nil { + ctx.SetStatusCode(fasthttp.StatusServiceUnavailable) + ctx.SetBodyString("Config store unavailable") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString("Failed to load consent flow.") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.") + return + } + + // Find which MCP clients the user has already authed. + // Check both: tokens stored under the flow proxy (connected during this flow) + // and tokens already stored under the VK/user identity (connected in a prior flow). + completedTokens, err := h.store.ConfigStore.GetOauthUserTokensByGatewaySessionID(ctx, flowID) + if err != nil { + completedTokens = nil // non-fatal; just show no checkmarks + } + completedMCPs := make(map[string]bool, len(completedTokens)) + for _, t := range completedTokens { + completedMCPs[t.MCPClientID] = true + } + + // Per_user_oauth MCP clients visible to this identity β€” sorted for deterministic rendering. + // When a VK is set on the flow, only show clients that VK is allowed to use. + perUserClients := h.store.GetPerUserOAuthMCPClientsForVirtualKey(ctx, strVal(flow.VirtualKeyID)) + clientIDs := make([]string, 0, len(perUserClients)) + for id := range perUserClients { + clientIDs = append(clientIDs, id) + } + sort.Strings(clientIDs) + + safeFlowID := html.EscapeString(flowID) + + // Determine if user skipped identity selection. + isSkipped := strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" + + // Build MCP rows β€” only show connect buttons if user has an identity. + var mcpRows strings.Builder + if isSkipped { + mcpRows.WriteString(`

You skipped identity selection. Services will be connected when you first use their tools. Since no identity is attached, your connections will only persist as long as the service keeps the OAuth token active β€” they will not be remembered across sessions.

`) + } else { + for _, clientID := range clientIDs { + clientName := perUserClients[clientID] + safeName := html.EscapeString(clientName) + + // Also check if a token already exists under the user's identity (e.g. from a prior LLM gateway auth). + alreadyConnected := completedMCPs[clientID] + if !alreadyConnected && (strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "") { + existing, tokenErr := h.store.ConfigStore.GetOauthUserTokenByIdentity(ctx, strVal(flow.VirtualKeyID), strVal(flow.UserID), "", clientID) + if tokenErr != nil { + logger.Warn("[consent/mcps] failed to check existing token: mcp_client_id=%s err=%v", clientID, tokenErr) + } + alreadyConnected = existing != nil + } + + if alreadyConnected { + mcpRows.WriteString(fmt.Sprintf(` +
+
%s
+ ✓ Connected +
`, safeName)) + } else { + connectURL := fmt.Sprintf("/api/oauth/per-user/upstream/authorize?mcp_client_id=%s&flow_id=%s", + url.QueryEscape(clientID), url.QueryEscape(flowID)) + mcpRows.WriteString(fmt.Sprintf(` +
+
%s
+ Connect +
`, safeName, html.EscapeString(connectURL))) + } + } + if len(perUserClients) == 0 { + mcpRows.WriteString(`

No MCP services require authentication.

`) + } + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetBodyString(fmt.Sprintf(` + + + + +Connect Your Apps β€” Bifrost + + + +
+

Connect Your Apps

+

Authenticate with the services below to enable their tools.

+

This setup page expires in 15 minutes.

+
%s
+
+ + +
+ +
+ +`, bifrostPageCSS, mcpRows.String(), safeFlowID, safeFlowID)) +} + +// ---------- API action handlers ---------- + +// handleSubmitVK validates the submitted Virtual Key, links it to the pending flow, +// and redirects to the MCPs page. +// POST /api/oauth/per-user/consent/vk (form: flow_id, vk) +func (h *ConsentHandler) handleSubmitVK(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + vkValue := strings.TrimSpace(string(ctx.FormValue("vk"))) + + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + if vkValue == "" { + redirectToIdentityPage(ctx, flowID, "Please enter a Virtual Key.") + return + } + + vk, err := h.store.ConfigStore.GetVirtualKeyByValue(ctx, vkValue) + if err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to validate Virtual Key. Please try again.") + return + } + if vk == nil || !vk.IsActive { + redirectToIdentityPage(ctx, flowID, "Virtual Key not found or inactive. Please check and try again.") + return + } + + flow.VirtualKeyID = &vk.ID + flow.UserID = nil // Clear other identity to keep selection exclusive + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to save Virtual Key. Please try again.") + return + } + + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSubmitUserID links a user-supplied User ID to the pending flow and proceeds to MCPs page. +// SECURITY: The User ID is self-declared (typed in a form) with no server-side verification. +// This matches the trust model of X-Bf-User-Id in the LLM gateway path. Deployments requiring +// verified identity should use Virtual Keys or an auth layer in front of Bifrost. +// POST /api/oauth/per-user/consent/user-id (form: flow_id, user_id) +func (h *ConsentHandler) handleSubmitUserID(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + userID := strings.TrimSpace(string(ctx.FormValue("user_id"))) + + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + if userID == "" { + redirectToIdentityPage(ctx, flowID, "Please enter a User ID.") + return + } + if len(userID) > 255 { + redirectToIdentityPage(ctx, flowID, "User ID is too long (max 255 characters).") + return + } + + if userID != "" { + flow.UserID = &userID + } + flow.VirtualKeyID = nil // Clear other identity to keep selection exclusive + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to save User ID. Please try again.") + return + } + + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSkip skips identity selection and proceeds directly to the MCPs page. +// Upstream services will be connected lazily when tools are first called. +// POST /api/oauth/per-user/consent/skip (form: flow_id) +func (h *ConsentHandler) handleSkip(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + h.store.Mu.RLock() + enforceVK := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + + if enforceVK { + redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.") + return + } + + // Clear any previously selected identity so skip truly resets the flow. + if strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "" { + flow.VirtualKeyID = nil + flow.UserID = nil + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to clear identity. Please try again.") + return + } + } + + // Skip goes straight to MCPs page; no identity means only lazy auth is available. + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSubmit finalises the consent flow: +// 1. Creates a real Bifrost session (TablePerUserOAuthSession) +// 2. Migrates upstream tokens from the flow proxy to the real session +// 3. Issues a TablePerUserOAuthCode +// 4. Deletes the PendingFlow +// 5. Redirects to the original MCP client callback URL with code + state +// +// POST /api/oauth/per-user/consent/submit (form: flow_id) +func (h *ConsentHandler) handleSubmit(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid consent flow") + return + } + if time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Consent flow has expired. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + // Server-side enforcement: reject if identity is required but not provided. + h.store.Mu.RLock() + enforceAuth := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + if enforceAuth && strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" { + redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.") + return + } + + // 1. Generate session credentials. + accessToken, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate session token") + return + } + refreshToken, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate refresh token") + return + } + + session := &tables.TablePerUserOAuthSession{ + ID: uuid.New().String(), + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: flow.ClientID, + VirtualKeyID: flow.VirtualKeyID, + UserID: flow.UserID, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + // 2. Generate authorization code. + code, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code") + return + } + codeRecord := &tables.TablePerUserOAuthCode{ + ID: uuid.New().String(), + Code: code, + ClientID: flow.ClientID, + RedirectURI: flow.RedirectURI, + CodeChallenge: flow.CodeChallenge, + SessionID: session.ID, // Links token endpoint to this session so it can return the same access token + // Scopes intentionally omitted: the consent flow has no scope selection step. + ExpiresAt: time.Now().Add(5 * time.Minute), + } + + // 3. Atomically consume the pending flow, create session, and create auth code. + // If another concurrent request already consumed the flow, rowsAffected will be 0. + rowsAffected, err := h.store.ConfigStore.FinalizePerUserOAuthConsent(ctx, flowID, session, codeRecord) + if err != nil { + if errors.Is(err, schemas.ErrPerUserOAuthPendingFlowExpired) { + SendError(ctx, fasthttp.StatusGone, "Consent flow has expired. Please restart the authentication process.") + return + } + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to finalize consent flow") + return + } + if rowsAffected == 0 { + SendError(ctx, fasthttp.StatusConflict, "Consent flow has already been submitted") + return + } + logger.Debug("[consent/submit] session created: session_id=%s flow_id=%s", session.ID, flowID) + + // 4. Migrate upstream tokens from flow proxy sessions to real session (non-fatal). + if err := h.store.ConfigStore.TransferOauthUserTokensFromGatewaySession(ctx, flowID, accessToken, strVal(flow.VirtualKeyID), strVal(flow.UserID)); err != nil { + // Non-fatal: tokens can be re-acquired on first tool use. + logger.Warn("[consent/submit] failed to transfer upstream tokens: flow_id=%s err=%v", flowID, err) + } + + // 5. Redirect to MCP client callback with code + original state. + redirectURL, err := url.Parse(flow.RedirectURI) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Invalid redirect URI in pending flow") + return + } + q := redirectURL.Query() + q.Set("code", code) + if flow.State != "" { + q.Set("state", flow.State) + } + redirectURL.RawQuery = q.Encode() + + ctx.Redirect(redirectURL.String(), fasthttp.StatusFound) +} + +// ---------- helpers ---------- + +// bifrostPageCSS is the shared inline CSS for all Go-rendered consent/callback pages. +// It mirrors Bifrost's UI design tokens: teal primary, zinc palette, Geist font stack. +const bifrostPageCSS = ` + *,*::before,*::after{box-sizing:border-box;margin:0;padding:0} + body{font-family:"Geist",system-ui,-apple-system,sans-serif;font-size:0.95rem; + line-height:1.5;background:#f4f4f5;color:oklch(0.141 0.005 285.823); + display:flex;align-items:center;justify-content:center;min-height:100vh; + -webkit-font-smoothing:antialiased} + .card{background:#fff;border:1px solid oklch(0.92 0.004 286.32);border-radius:12px; + padding:40px;width:100%;max-width:480px} + h1{font-size:1.25rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:6px} + .subtitle{font-size:0.825rem;color:oklch(0.552 0.016 285.938);line-height:1.5;margin-bottom:24px} + label{display:block;font-size:0.825rem;font-weight:500;color:oklch(0.141 0.005 285.823);margin-bottom:5px} + input[type=text],input[type=password]{width:100%;padding:8px 12px;border:1px solid oklch(0.92 0.004 286.32); + border-radius:0.5rem;font-size:0.875rem;outline:none; + transition:border-color .15s,box-shadow .15s;margin-bottom:10px; + background:#fff;color:oklch(0.141 0.005 285.823)} + input[type=text]:focus,input[type=password]:focus{border-color:oklch(0.5081 0.1049 165.61); + box-shadow:0 0 0 3px oklch(0.5081 0.1049 165.61 / 0.15)} + .btn{display:block;width:100%;padding:9px 16px;border-radius:0.5rem;font-size:0.875rem; + font-weight:500;cursor:pointer;border:none;text-align:center;text-decoration:none; + transition:background .15s;font-family:inherit} + .btn-primary{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0)} + .btn-primary:hover{background:oklch(0.43 0.1049 165.61)} + .btn-ghost{background:transparent;border:1px solid oklch(0.92 0.004 286.32); + color:oklch(0.552 0.016 285.938);display:inline-block;width:auto;padding:8px 16px} + .btn-ghost:hover{background:#f4f4f5} + .error-banner{background:oklch(0.97 0.02 27);border:1px solid oklch(0.88 0.06 27); + border-radius:0.5rem;padding:12px 14px;margin-bottom:18px; + color:oklch(0.50 0.18 27);font-size:0.825rem} +` + +// redirectToIdentityPage redirects to the identity selection page with an error message. +func redirectToIdentityPage(ctx *fasthttp.RequestCtx, flowID, errorMsg string) { + u := fmt.Sprintf("/oauth/consent?flow_id=%s&error=%s", + url.QueryEscape(flowID), url.QueryEscape(errorMsg)) + ctx.Redirect(u, fasthttp.StatusFound) +} + +// strVal safely dereferences a *string, returning "" for nil. +func strVal(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/transports/bifrost-http/handlers/oauth2_metadata.go b/transports/bifrost-http/handlers/oauth2_metadata.go new file mode 100644 index 0000000000..2a764291e4 --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_metadata.go @@ -0,0 +1,93 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file implements OAuth 2.0 metadata discovery endpoints per RFC 9728 +// (Protected Resource Metadata) and RFC 8414 (Authorization Server Metadata). +// These endpoints enable MCP-spec-compliant clients (like Claude Code) to +// automatically discover Bifrost's OAuth configuration and authenticate. +package handlers + +import ( + "fmt" + + "github.com/fasthttp/router" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// OAuthMetadataHandler serves OAuth 2.0 discovery metadata endpoints. +// It provides the Protected Resource Metadata (RFC 9728) and Authorization +// Server Metadata (RFC 8414) that MCP clients use to discover how to +// authenticate with Bifrost's MCP server endpoint. +type OAuthMetadataHandler struct { + store *lib.Config +} + +// NewOAuthMetadataHandler creates a new OAuth metadata handler instance. +func NewOAuthMetadataHandler(store *lib.Config) *OAuthMetadataHandler { + return &OAuthMetadataHandler{store: store} +} + +// RegisterRoutes registers the well-known metadata discovery routes. +// These routes do NOT go through auth middleware since they must be +// accessible to unauthenticated clients during OAuth discovery. +func (h *OAuthMetadataHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + // RFC 9728: Protected Resource Metadata + r.GET("/.well-known/oauth-protected-resource", lib.ChainMiddlewares(h.handleProtectedResourceMetadata, middlewares...)) + // RFC 8414: Authorization Server Metadata + r.GET("/.well-known/oauth-authorization-server", lib.ChainMiddlewares(h.handleAuthorizationServerMetadata, middlewares...)) +} + +// handleProtectedResourceMetadata serves the Protected Resource Metadata +// document per RFC 9728. MCP clients fetch this after receiving a 401 response +// to discover which authorization server(s) protect the MCP resource. +// +// GET /.well-known/oauth-protected-resource +func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.RequestCtx) { + if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not Found") + return + } + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + baseURL := fmt.Sprintf("%s://%s", scheme, host) + + SendJSON(ctx, map[string]interface{}{ + "resource": baseURL + "/mcp", + "authorization_servers": []string{baseURL}, + "scopes_supported": []string{"mcp:read", "mcp:write"}, + "bearer_methods_supported": []string{"header"}, + }) +} + +// handleAuthorizationServerMetadata serves the Authorization Server Metadata +// document per RFC 8414. MCP clients use this to discover Bifrost's OAuth +// endpoints (authorize, token, register) and supported capabilities. +// +// GET /.well-known/oauth-authorization-server +func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.RequestCtx) { + if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not Found") + return + } + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + baseURL := fmt.Sprintf("%s://%s", scheme, host) + + SendJSON(ctx, map[string]interface{}{ + "issuer": baseURL, + "authorization_endpoint": baseURL + "/api/oauth/per-user/authorize", + "token_endpoint": baseURL + "/api/oauth/per-user/token", + "registration_endpoint": baseURL + "/api/oauth/per-user/register", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + "scopes_supported": []string{"mcp:read", "mcp:write"}, + }) +} diff --git a/transports/bifrost-http/handlers/oauth2_per_user.go b/transports/bifrost-http/handlers/oauth2_per_user.go new file mode 100644 index 0000000000..3c6f59041b --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_per_user.go @@ -0,0 +1,566 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file implements Bifrost's OAuth 2.1 Authorization Server for per-user MCP +// authentication. It provides Dynamic Client Registration (RFC 7591), Authorization +// Code flow with PKCE, and token issuance. MCP clients (Claude Code, IDEs) use +// these endpoints to authenticate users before accessing Bifrost's /mcp endpoint. +package handlers + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "net/url" + "strings" + "time" + + "github.com/fasthttp/router" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// PerUserOAuthHandler implements Bifrost's OAuth 2.1 Authorization Server. +// It handles dynamic client registration, authorization code issuance with PKCE, +// and token exchange for MCP per-user authentication. +type PerUserOAuthHandler struct { + store *lib.Config +} + +// NewPerUserOAuthHandler creates a new per-user OAuth handler instance. +func NewPerUserOAuthHandler(store *lib.Config) *PerUserOAuthHandler { + return &PerUserOAuthHandler{store: store} +} + +// RegisterRoutes registers the per-user OAuth authorization server routes. +// These routes do NOT go through auth middleware since they are part of the +// OAuth flow that unauthenticated clients use to obtain tokens. +func (h *PerUserOAuthHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + r.POST("/api/oauth/per-user/register", lib.ChainMiddlewares(h.handleDynamicClientRegistration, middlewares...)) + r.GET("/api/oauth/per-user/authorize", lib.ChainMiddlewares(h.handleAuthorize, middlewares...)) + r.POST("/api/oauth/per-user/token", lib.ChainMiddlewares(h.handleToken, middlewares...)) + r.GET("/api/oauth/per-user/upstream/authorize", lib.ChainMiddlewares(h.handleUpstreamAuthorize, middlewares...)) +} + +// handleDynamicClientRegistration handles OAuth 2.0 Dynamic Client Registration +// per RFC 7591. MCP clients register themselves to obtain a client_id. +// +// POST /api/oauth/per-user/register +func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth registration unavailable: config store is disabled") + return + } + + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + + var req struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + Scope string `json:"scope"` + } + + if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid registration request: %v", err)) + return + } + + if len(req.RedirectURIs) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "redirect_uris is required") + return + } + + // Generate client_id + clientID := uuid.New().String() + + // Serialize arrays + redirectURIsJSON, _ := json.Marshal(req.RedirectURIs) + grantTypes := req.GrantTypes + if len(grantTypes) == 0 { + grantTypes = []string{"authorization_code"} + } + grantTypesJSON, _ := json.Marshal(grantTypes) + + client := &tables.TablePerUserOAuthClient{ + ID: uuid.New().String(), + ClientID: clientID, + ClientName: req.ClientName, + RedirectURIs: string(redirectURIsJSON), + GrantTypes: string(grantTypesJSON), + } + + if err := h.store.ConfigStore.CreatePerUserOAuthClient(ctx, client); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + // Return RFC 7591 response + ctx.SetStatusCode(fasthttp.StatusCreated) + SendJSON(ctx, map[string]interface{}{ + "client_id": clientID, + "client_name": req.ClientName, + "redirect_uris": req.RedirectURIs, + "grant_types": grantTypes, + "response_types": req.ResponseTypes, + "token_endpoint_auth_method": "none", + }) +} + +// handleAuthorize handles the OAuth 2.1 authorization endpoint. +// Instead of issuing a code immediately, it validates the request parameters, +// creates a PendingFlow record, and redirects the user to the consent screen. +// The code is only issued after the user completes the consent flow (VK + MCP auths). +// +// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&code_challenge=xxx&code_challenge_method=S256[&state=xxx] +func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth authorization unavailable: config store is disabled") + return + } + + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + + // Extract parameters + responseType := string(ctx.QueryArgs().Peek("response_type")) + clientID := string(ctx.QueryArgs().Peek("client_id")) + redirectURI := string(ctx.QueryArgs().Peek("redirect_uri")) + state := string(ctx.QueryArgs().Peek("state")) + codeChallenge := string(ctx.QueryArgs().Peek("code_challenge")) + codeChallengeMethod := string(ctx.QueryArgs().Peek("code_challenge_method")) + + // Validate required parameters + if responseType != "code" { + SendError(ctx, fasthttp.StatusBadRequest, "response_type must be 'code'") + return + } + if clientID == "" || redirectURI == "" { + SendError(ctx, fasthttp.StatusBadRequest, "client_id and redirect_uri are required") + return + } + if codeChallenge == "" || codeChallengeMethod != "S256" { + SendError(ctx, fasthttp.StatusBadRequest, "PKCE is required: code_challenge and code_challenge_method=S256") + return + } + + // Validate client exists and redirect_uri is registered + client, err := h.store.ConfigStore.GetPerUserOAuthClientByClientID(ctx, clientID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to validate client: %v", err)) + return + } + if client == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Unknown client_id") + return + } + var allowedURIs []string + json.Unmarshal([]byte(client.RedirectURIs), &allowedURIs) + uriAllowed := false + for _, allowed := range allowedURIs { + if allowed == redirectURI { + uriAllowed = true + break + } + } + if !uriAllowed { + SendError(ctx, fasthttp.StatusBadRequest, "redirect_uri not registered for this client") + return + } + + // Generate a browser-binding secret so only the initiating browser can resume this flow. + browserSecret, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate browser secret") + return + } + browserSecretHash := fmt.Sprintf("%x", sha256.Sum256([]byte(browserSecret))) + + // Create a PendingFlow to carry OAuth params through the consent screen. + flow := &tables.TablePerUserOAuthPendingFlow{ + ID: uuid.New().String(), + ClientID: clientID, + RedirectURI: redirectURI, + CodeChallenge: codeChallenge, + State: state, + BrowserSecretHash: browserSecretHash, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + if err := h.store.ConfigStore.CreatePerUserOAuthPendingFlow(ctx, flow); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create pending flow: %v", err)) + return + } + logger.Debug("[oauth/authorize] PendingFlow created: flow_id=%s client_id=%s", flow.ID, clientID) + + // Set HttpOnly cookie binding this flow to the current browser. + var cookie fasthttp.Cookie + cookie.SetKey("__bifrost_flow_secret") + cookie.SetValue(browserSecret) + cookie.SetPath("/") + cookie.SetHTTPOnly(true) + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + isSecure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" + cookie.SetSecure(isSecure) + cookie.SetMaxAge(15 * 60) // 15 minutes, matching flow TTL + ctx.Response.Header.SetCookie(&cookie) + + // Redirect to consent screen with flow_id (relative path β€” stays on current origin). + consentURL := fmt.Sprintf("/oauth/consent?flow_id=%s", url.QueryEscape(flow.ID)) + ctx.Redirect(consentURL, fasthttp.StatusFound) +} + +// handleToken handles the OAuth 2.1 token endpoint. +// It validates the authorization code + PKCE verifier and issues access/refresh tokens. +// +// POST /api/oauth/per-user/token +func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth token endpoint unavailable: config store is disabled") + return + } + + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + + // Parse form-encoded body + grantType := string(ctx.FormValue("grant_type")) + code := string(ctx.FormValue("code")) + redirectURI := string(ctx.FormValue("redirect_uri")) + clientID := string(ctx.FormValue("client_id")) + codeVerifier := string(ctx.FormValue("code_verifier")) + + if grantType != "authorization_code" { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "unsupported_grant_type", "Only authorization_code grant is supported") + return + } + + if code == "" || codeVerifier == "" { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code and code_verifier are required") + return + } + + // Atomically claim authorization code (prevents concurrent redemption) + codeRecord, err := h.store.ConfigStore.ClaimPerUserOAuthCode(ctx, code) + if err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to validate code") + return + } + if codeRecord == nil { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Invalid or already used authorization code") + return + } + + // Validate code is not expired + if time.Now().After(codeRecord.ExpiresAt) { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Authorization code expired") + return + } + + // Validate client_id if provided β€” some public clients omit it (RFC 6749 Β§4.1.3 allows + // omitting client_id when the client is not authenticating with the server). + // The code record already binds the code to the correct client, so this is safe. + if clientID != "" && codeRecord.ClientID != clientID { + logger.Debug("[oauth/token] client_id mismatch: code_client=%s request_client=%s", codeRecord.ClientID, clientID) + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "client_id mismatch") + return + } + // Use the client_id from the code record as the authoritative value. + clientID = codeRecord.ClientID + + // Validate redirect_uri matches + if redirectURI != "" && codeRecord.RedirectURI != redirectURI { + logger.Debug("[oauth/token] redirect_uri mismatch: code=%s request=%s", codeRecord.RedirectURI, redirectURI) + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") + return + } + + // Validate PKCE: SHA256(code_verifier) must match code_challenge + verifierHash := sha256.Sum256([]byte(codeVerifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) + if computedChallenge != codeRecord.CodeChallenge { + logger.Debug("[oauth/token] PKCE verification failed") + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "PKCE verification failed") + return + } + + // If the code was issued by the consent flow (handleSubmit), the session already exists + // with the upstream tokens transferred to it. Reuse that session's access token so the + // client receives the token that the upstream (Notion, GitHub, etc.) tokens are linked to. + var accessToken string + var expiresAt time.Time + + if codeRecord.SessionID != "" { + existingSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, codeRecord.SessionID) + if err != nil { + logger.Info("[oauth/token] Failed to load existing session: session_id=%s err=%v", codeRecord.SessionID, err) + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to load session") + return + } + if existingSession == nil { + logger.Info("[oauth/token] Existing session not found: session_id=%s", codeRecord.SessionID) + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Session not found") + return + } + if !existingSession.ExpiresAt.After(time.Now()) { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Session expired") + return + } + accessToken = existingSession.AccessToken + expiresAt = existingSession.ExpiresAt + logger.Debug("[oauth/token] reusing consent session: session_id=%s", existingSession.ID) + } else { + // Fallback: no linked session (legacy path) β€” create a new one. + var newAccessToken, newRefreshToken string + newAccessToken, err = generateOpaqueToken(32) + if err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token") + return + } + newRefreshToken, err = generateOpaqueToken(32) + if err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token") + return + } + expiresAt = time.Now().Add(24 * time.Hour) + newSession := &tables.TablePerUserOAuthSession{ + ID: uuid.New().String(), + AccessToken: newAccessToken, + RefreshToken: newRefreshToken, + ClientID: clientID, + ExpiresAt: expiresAt, + } + if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, newSession); err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session") + return + } + accessToken = newAccessToken + logger.Debug("[oauth/token] created new session (legacy path): session_id=%s", newSession.ID) + } + // Return OAuth token response + ctx.SetContentType("application/json") + ctx.SetStatusCode(fasthttp.StatusOK) + SendJSON(ctx, map[string]interface{}{ + "access_token": accessToken, + "token_type": "Bearer", + "expires_in": int(time.Until(expiresAt).Seconds()), + "scope": codeRecord.Scopes, + }) +} + +// sendOAuthError sends an OAuth 2.0 error response per RFC 6749 Section 5.2. +func sendOAuthError(ctx *fasthttp.RequestCtx, statusCode int, errorCode, description string) { + ctx.SetContentType("application/json") + ctx.SetStatusCode(statusCode) + resp, _ := json.Marshal(map[string]string{ + "error": errorCode, + "error_description": description, + }) + ctx.SetBody(resp) +} + +func sendStringError(ctx *fasthttp.RequestCtx, statusCode int, message string) { + ctx.SetContentType("text/plain") + ctx.SetStatusCode(statusCode) + ctx.SetBodyString(message) +} + +// generateOpaqueToken generates a cryptographically secure random token. +// validateFlowBrowserSecret checks that the request carries the __bifrost_flow_secret +// cookie matching the hash stored on the pending flow. Returns true if valid. +func validateFlowBrowserSecret(ctx *fasthttp.RequestCtx, flow *tables.TablePerUserOAuthPendingFlow) bool { + if flow.BrowserSecretHash == "" { + // Legacy flow without browser binding β€” allow for backwards compatibility. + return true + } + secret := ctx.Request.Header.Cookie("__bifrost_flow_secret") + if len(secret) == 0 { + return false + } + hash := fmt.Sprintf("%x", sha256.Sum256(secret)) + return hash == flow.BrowserSecretHash +} + +func generateOpaqueToken(length int) (string, error) { + bytes := make([]byte, length) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +// handleUpstreamAuthorize handles the upstream OAuth proxy for per-user OAuth. +// When a user needs to authenticate with an upstream MCP server (e.g., Notion), +// this endpoint redirects them to the upstream provider's OAuth authorize URL. +// After the user authenticates, the callback stores their upstream token linked +// to either their Bifrost session (runtime flow) or a PendingFlow (consent flow). +// +// Runtime flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx +// Consent flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&flow_id=xxx +func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth upstream authorization unavailable: config store is disabled") + return + } + + mcpClientID := string(ctx.QueryArgs().Peek("mcp_client_id")) + sessionID := string(ctx.QueryArgs().Peek("session")) + flowID := string(ctx.QueryArgs().Peek("flow_id")) + + if mcpClientID == "" || (sessionID == "" && flowID == "") { + SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and either session or flow_id are required") + return + } + + // Resolve identity depending on whether this is a runtime session or a consent flow. + var virtualKeyID, userID, proxySessionToken, gatewaySessionID string + if flowID != "" { + // Consent flow: use the pending flow for identity and proxy token. + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil || flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + if strVal(flow.VirtualKeyID) != "" { + virtualKeyID = *flow.VirtualKeyID + } + if strVal(flow.UserID) != "" { + userID = *flow.UserID + } + // Use a prefixed flow token so the callback can detect the consent path. + // Include mcpClientID to avoid unique constraint violations when multiple + // MCP services are connected in the same consent flow. + proxySessionToken = "flow:" + flowID + ":" + mcpClientID + gatewaySessionID = flowID + } else { + // Runtime flow: validate the existing Bifrost session. + bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID) + if err != nil || bifrostSession == nil { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session") + return + } + if !bifrostSession.ExpiresAt.After(time.Now()) { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session") + return + } + virtualKeyID = strVal(bifrostSession.VirtualKeyID) + userID = strVal(bifrostSession.UserID) + proxySessionToken = "runtime:" + sessionID + ":" + mcpClientID + gatewaySessionID = sessionID + } + + // Look up the MCP client config to get the template OAuth config. + mcpClient, err := h.store.ConfigStore.GetMCPClientByID(ctx, mcpClientID) + if err != nil || mcpClient == nil { + SendError(ctx, fasthttp.StatusNotFound, "MCP client not found") + return + } + if mcpClient.AuthType != string(schemas.MCPAuthTypePerUserOauth) { + SendError(ctx, fasthttp.StatusBadRequest, "MCP client does not use per-user OAuth") + return + } + if mcpClient.OauthConfigID == nil || *mcpClient.OauthConfigID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "MCP client has no OAuth configuration") + return + } + + // Load template OAuth config (has upstream authorize_url, client_id, etc.) + templateConfig, err := h.store.ConfigStore.GetOauthConfigByID(ctx, *mcpClient.OauthConfigID) + if err != nil || templateConfig == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load OAuth template config") + return + } + + // Generate PKCE challenge for upstream. + codeVerifier, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate PKCE verifier") + return + } + verifierHash := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) + + // Generate state for upstream. + state, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate state token") + return + } + + // Build redirect URI (Bifrost's callback endpoint). + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host) + var vkId *string + if virtualKeyID != "" { + vkId = &virtualKeyID + } + var uid *string + if userID != "" { + uid = &userID + } + // Store upstream OAuth session linking state β†’ MCP client + identity. + upstreamSession := &tables.TableOauthUserSession{ + ID: uuid.New().String(), + MCPClientID: mcpClientID, + OauthConfigID: *mcpClient.OauthConfigID, + State: state, + CodeVerifier: codeVerifier, + SessionToken: proxySessionToken, // "runtime:xxx" for runtime flow; "flow:xxx" for consent flow + GatewaySessionID: gatewaySessionID, + VirtualKeyID: vkId, + UserID: uid, + Status: "pending", + ExpiresAt: time.Now().Add(15 * time.Minute), + } + logger.Debug("[oauth/upstream-authorize] creating upstream session: mcp_client=%s flow=%s", mcpClientID, proxySessionToken) + if err := h.store.ConfigStore.CreateOauthUserSession(ctx, upstreamSession); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create upstream OAuth session: %v", err)) + return + } + + // Parse scopes from template config. + var scopes []string + if templateConfig.Scopes != "" { + json.Unmarshal([]byte(templateConfig.Scopes), &scopes) + } + + // Build upstream authorize URL with PKCE. + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", templateConfig.ClientID) + params.Set("redirect_uri", redirectURI) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + if len(scopes) > 0 { + params.Set("scope", strings.Join(scopes, " ")) + } + + upstreamAuthorizeURL := templateConfig.AuthorizeURL + "?" + params.Encode() + ctx.Redirect(upstreamAuthorizeURL, fasthttp.StatusFound) +} + +// Ensure unused imports are referenced. +var _ = html.EscapeString +var _ configstore.ConfigStore diff --git a/transports/bifrost-http/handlers/pricing_override_test.go b/transports/bifrost-http/handlers/pricing_override_test.go new file mode 100644 index 0000000000..4d19d0541e --- /dev/null +++ b/transports/bifrost-http/handlers/pricing_override_test.go @@ -0,0 +1,149 @@ +package handlers + +import ( + "context" + "encoding/json" + "net" + "os" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/framework/modelcatalog" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +type pricingOverrideTestGovernanceManager struct{} + +func (pricingOverrideTestGovernanceManager) GetGovernanceData() *governance.GovernanceData { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadVirtualKey(context.Context, string) (*configstoreTables.TableVirtualKey, error) { + return nil, nil +} +func (pricingOverrideTestGovernanceManager) RemoveVirtualKey(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadTeam(context.Context, string) (*configstoreTables.TableTeam, error) { + return nil, nil +} +func (pricingOverrideTestGovernanceManager) RemoveTeam(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadCustomer(context.Context, string) (*configstoreTables.TableCustomer, error) { + return nil, nil +} +func (pricingOverrideTestGovernanceManager) RemoveCustomer(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadModelConfig(context.Context, string) (*configstoreTables.TableModelConfig, error) { + return nil, nil +} +func (pricingOverrideTestGovernanceManager) RemoveModelConfig(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadProvider(context.Context, schemas.ModelProvider) (*configstoreTables.TableProvider, error) { + return nil, nil +} +func (pricingOverrideTestGovernanceManager) RemoveProvider(context.Context, schemas.ModelProvider) error { + return nil +} +func (pricingOverrideTestGovernanceManager) ReloadRoutingRule(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) RemoveRoutingRule(context.Context, string) error { + return nil +} +func (pricingOverrideTestGovernanceManager) UpsertPricingOverride(context.Context, *configstoreTables.TablePricingOverride) error { + return nil +} +func (pricingOverrideTestGovernanceManager) DeletePricingOverride(context.Context, string) error { + return nil +} + +func setupPricingOverrideHandlerStore(t *testing.T) configstore.ConfigStore { + t.Helper() + + dbPath := t.TempDir() + "/config.db" + store, err := configstore.NewConfigStore(context.Background(), &configstore.Config{ + Enabled: true, + Type: configstore.ConfigStoreTypeSQLite, + Config: &configstore.SQLiteConfig{ + Path: dbPath, + }, + }, &mockLogger{}) + require.NoError(t, err) + + t.Cleanup(func() { + _ = os.Remove(dbPath) + }) + return store +} + +func newTestRequestCtx(body string) *fasthttp.RequestCtx { + var req fasthttp.Request + req.SetBodyString(body) + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}, nil) + return ctx +} + +func TestUpdatePricingOverride_ReplacesFullBody(t *testing.T) { + SetLogger(&mockLogger{}) + store := setupPricingOverrideHandlerStore(t) + handler := &GovernanceHandler{ + configStore: store, + governanceManager: pricingOverrideTestGovernanceManager{}, + } + + now := time.Now().UTC() + override := configstoreTables.TablePricingOverride{ + ID: "override-1", + Name: "Original", + ScopeKind: string(modelcatalog.ScopeKindGlobal), + MatchType: string(modelcatalog.MatchTypeExact), + Pattern: "gpt-4.1", + CreatedAt: now, + UpdatedAt: now, + PricingPatchJSON: `{"input_cost_per_token":1,"output_cost_per_token":2}`, + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, + } + require.NoError(t, store.CreatePricingOverride(context.Background(), &override)) + + // Patch replaces in full: send only input_cost_per_token. + // output_cost_per_token must be absent from the stored patch afterwards, + // confirming full-replace (not merge) semantics. + body := `{ + "name":"Updated", + "scope_kind":"global", + "match_type":"exact", + "pattern":"gpt-4.1", + "request_types":["chat_completion"], + "patch":{"input_cost_per_token":1.5} + }` + ctx := newTestRequestCtx(body) + ctx.SetUserValue("id", override.ID) + + handler.updatePricingOverride(ctx) + + require.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode(), string(ctx.Response.Body())) + + stored, err := store.GetPricingOverrideByID(context.Background(), override.ID) + require.NoError(t, err) + assert.Equal(t, "Updated", stored.Name) + + var patch modelcatalog.PricingOptions + require.NoError(t, json.Unmarshal([]byte(stored.PricingPatchJSON), &patch)) + // Sent field must reflect the new value. + require.NotNil(t, patch.InputCostPerToken) + assert.Equal(t, 1.5, *patch.InputCostPerToken) + // Omitted field must be cleared β€” patch is always fully replaced, not merged. + assert.Nil(t, patch.OutputCostPerToken) + assert.Empty(t, stored.ConfigHash) +} diff --git a/transports/bifrost-http/handlers/prompts.go b/transports/bifrost-http/handlers/prompts.go index 9f4ac48470..e5b96f0c38 100644 --- a/transports/bifrost-http/handlers/prompts.go +++ b/transports/bifrost-http/handlers/prompts.go @@ -1,8 +1,10 @@ package handlers import ( + "context" "encoding/json" "errors" + "fmt" "strconv" "github.com/fasthttp/router" @@ -14,15 +16,35 @@ import ( "github.com/valyala/fasthttp" ) +// PromptCacheReloader is implemented by the prompts plugin to allow the HTTP handler +// to trigger an in-memory cache refresh after any repository mutation. +type PromptCacheReloader interface { + Reload(ctx context.Context) error +} + // PromptsHandler handles prompt repository endpoints type PromptsHandler struct { - store configstore.ConfigStore + store configstore.ConfigStore + reloader PromptCacheReloader // optional; nil when the prompts plugin is not loaded +} + +// NewPromptsHandler creates a new PromptsHandler. +// reloader may be nil; when set, the in-memory prompt cache is refreshed after mutations. +func NewPromptsHandler(store configstore.ConfigStore, reloader PromptCacheReloader) *PromptsHandler { + if store == nil { + return nil + } + return &PromptsHandler{store: store, reloader: reloader} } -// NewPromptsHandler creates a new PromptsHandler -func NewPromptsHandler(store configstore.ConfigStore) *PromptsHandler { - return &PromptsHandler{ - store: store, +// reloadCache triggers a cache refresh if a reloader is configured. +// Errors are logged but do not fail the originating request. +func (h *PromptsHandler) reloadCache(ctx context.Context) { + if h.reloader == nil { + return + } + if err := h.reloader.Reload(ctx); err != nil { + logger.Error("failed to reload prompt cache: %v", err) } } @@ -143,7 +165,8 @@ type RenameSessionRequest struct { // CommitSessionRequest represents the request body for committing a session as a version type CommitSessionRequest struct { - CommitMessage string `json:"commit_message"` + CommitMessage string `json:"commit_message"` + MessageIndices *[]int `json:"message_indices,omitempty"` // optional: indices of messages to include (0-based). If nil/absent, all messages are included. } // ============================================================================ @@ -294,6 +317,7 @@ func (h *PromptsHandler) deleteFolder(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "folder deleted successfully", }) @@ -392,6 +416,7 @@ func (h *PromptsHandler) createPrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "prompt": prompt, }) @@ -465,6 +490,7 @@ func (h *PromptsHandler) updatePrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "prompt": prompt, }) @@ -493,6 +519,7 @@ func (h *PromptsHandler) deletePrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "prompt deleted successfully", }) @@ -619,6 +646,7 @@ func (h *PromptsHandler) createVersion(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "version": version, }) @@ -652,6 +680,7 @@ func (h *PromptsHandler) deleteVersion(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "version deleted successfully", }) @@ -1005,11 +1034,36 @@ func (h *PromptsHandler) commitSession(ctx *fasthttp.RequestCtx) { // Convert session messages to version messages var messages []tables.TablePromptVersionMessage - for _, msg := range session.Messages { - messages = append(messages, tables.TablePromptVersionMessage{ - PromptID: session.PromptID, - Message: msg.Message, - }) + if req.MessageIndices != nil { + // Only include messages at the specified indices, deduplicating + seen := make(map[int]struct{}) + for _, idx := range *req.MessageIndices { + if _, ok := seen[idx]; ok { + continue + } + seen[idx] = struct{}{} + if idx < 0 || idx >= len(session.Messages) { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("message index %d out of range (0-%d)", idx, len(session.Messages)-1)) + return + } + msg := session.Messages[idx] + messages = append(messages, tables.TablePromptVersionMessage{ + PromptID: session.PromptID, + Message: msg.Message, + }) + } + } else { + for _, msg := range session.Messages { + messages = append(messages, tables.TablePromptVersionMessage{ + PromptID: session.PromptID, + Message: msg.Message, + }) + } + } + + if len(messages) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "at least one message must be included in the version") + return } version := &tables.TablePromptVersion{ @@ -1027,6 +1081,7 @@ func (h *PromptsHandler) commitSession(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "version": version, }) diff --git a/transports/bifrost-http/handlers/provider_keys.go b/transports/bifrost-http/handlers/provider_keys.go new file mode 100644 index 0000000000..efd23d0bfb --- /dev/null +++ b/transports/bifrost-http/handlers/provider_keys.go @@ -0,0 +1,495 @@ +package handlers + +import ( + "errors" + "fmt" + "net/url" + + "github.com/bytedance/sonic" + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ListProviderKeysResponse represents the response for listing keys for a provider. +type ListProviderKeysResponse struct { + Keys []schemas.Key `json:"keys"` + Total int `json:"total"` +} + +func (h *ProviderHandler) listProviderKeys(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keys, err := h.inMemoryStore.GetProviderKeysRedacted(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider keys: %v", err)) + return + } + + SendJSON(ctx, ListProviderKeysResponse{Keys: keys, Total: len(keys)}) +} + +func (h *ProviderHandler) getProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + key, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + SendJSON(ctx, key) +} + +func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + var key schemas.Key + if err := sonic.Unmarshal(ctx.PostBody(), &key); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot add keys to a keyless provider") + return + } + + baseProvider := provider + if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" { + baseProvider = providerConfig.CustomProviderConfig.BaseProviderType + } + + if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && key.Value.GetValue() == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty") + return + } + + if err := validateProviderKeyURL(provider, key); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + if err := key.BlacklistedModels.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err)) + return + } + + if err := key.Aliases.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err)) + return + } + + if key.ID == "" { + key.ID = uuid.NewString() + } + if key.Enabled == nil { + key.Enabled = bifrost.Ptr(true) + } + + if err := h.inMemoryStore.AddProviderKey(ctx, provider, key); err != nil { + logger.Warn("Failed to create key for provider %s: %v", provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + if errors.Is(err, lib.ErrAlreadyExists) { + SendError(ctx, fasthttp.StatusConflict, err.Error()) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key create: %v", provider, err) + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, key.ID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get created provider key: %v", err)) + return + } + + SendJSON(ctx, redactedKey) +} + +func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + var updateKey schemas.Key + if err := sonic.Unmarshal(ctx.PostBody(), &updateKey); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot update keys on a keyless provider") + return + } + + oldRawKey, err := h.inMemoryStore.GetProviderKeyRaw(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + oldRedactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + updateKey.ID = keyID + mergedKey := h.mergeUpdatedKey(*oldRawKey, *oldRedactedKey, updateKey) + + baseProvider := provider + if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.BaseProviderType != "" { + baseProvider = providerConfig.CustomProviderConfig.BaseProviderType + } + + if !bifrost.CanProviderKeyValueBeEmpty(baseProvider) && mergedKey.Value.GetValue() == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty") + return + } + + if err := mergedKey.BlacklistedModels.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid blacklisted_models: %v", err)) + return + } + + if err := mergedKey.Aliases.Validate(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid aliases: %v", err)) + return + } + + if err := validateProviderKeyURL(provider, mergedKey); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + if err := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, mergedKey); err != nil { + logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key update: %v", provider, err) + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get updated provider key: %v", err)) + return + } + + SendJSON(ctx, redactedKey) +} + +func (h *ProviderHandler) deleteProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot delete keys on a keyless provider") + return + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + if err := h.inMemoryStore.RemoveProviderKey(ctx, provider, keyID); err != nil { + logger.Warn("Failed to delete key %s for provider %s: %v", keyID, provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key delete: %v", provider, err) + } + + SendJSON(ctx, redactedKey) +} + +// mergeUpdatedKey merges an updated key with the old raw/redacted versions, +// preserving real values for fields that were sent back in redacted form. +func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey schemas.Key) schemas.Key { + mergedKey := updateKey + + if updateKey.Value.IsRedacted() && updateKey.Value.Equals(&oldRedactedKey.Value) { + mergedKey.Value = oldRawKey.Value + } + + if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil { + if updateKey.AzureKeyConfig.Endpoint.IsRedacted() && + updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) { + mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint + } + if updateKey.AzureKeyConfig.APIVersion != nil && + oldRedactedKey.AzureKeyConfig.APIVersion != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.APIVersion.IsRedacted() && + updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) { + mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion + } + if updateKey.AzureKeyConfig.ClientID != nil && + oldRedactedKey.AzureKeyConfig.ClientID != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.ClientID.IsRedacted() && + updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) { + mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID + } + if updateKey.AzureKeyConfig.ClientSecret != nil && + oldRedactedKey.AzureKeyConfig.ClientSecret != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.ClientSecret.IsRedacted() && + updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) { + mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret + } + if updateKey.AzureKeyConfig.TenantID != nil && + oldRedactedKey.AzureKeyConfig.TenantID != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.TenantID.IsRedacted() && + updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) { + mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID + } + } + + if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil { + if updateKey.VertexKeyConfig.ProjectID.IsRedacted() && + updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) { + mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID + } + if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() && + updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) { + mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber + } + if updateKey.VertexKeyConfig.Region.IsRedacted() && + updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) { + mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region + } + if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() && + updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) { + mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials + } + } + + if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil { + if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() && + updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) { + mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey + } + if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() && + updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) { + mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey + } + if updateKey.BedrockKeyConfig.SessionToken != nil && + oldRedactedKey.BedrockKeyConfig.SessionToken != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.SessionToken.IsRedacted() && + updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) { + mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken + } + if updateKey.BedrockKeyConfig.Region != nil && + oldRedactedKey.BedrockKeyConfig.Region != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.Region.IsRedacted() && + updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) { + mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region + } + if updateKey.BedrockKeyConfig.ARN != nil && + oldRedactedKey.BedrockKeyConfig.ARN != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.ARN.IsRedacted() && + updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) { + mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN + } + if updateKey.BedrockKeyConfig.RoleARN != nil && + oldRedactedKey.BedrockKeyConfig.RoleARN != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.RoleARN.IsRedacted() && + updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) { + mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN + } + if updateKey.BedrockKeyConfig.ExternalID != nil && + oldRedactedKey.BedrockKeyConfig.ExternalID != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.ExternalID.IsRedacted() && + updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) { + mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID + } + if updateKey.BedrockKeyConfig.RoleSessionName != nil && + oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() && + updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) { + mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName + } + } + + if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil { + if updateKey.VLLMKeyConfig.URL.IsRedacted() && + updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) { + mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL + } + } + + // ReplicateKeyConfig has no sensitive fields β€” pass through as-is + if updateKey.ReplicateKeyConfig == nil && oldRawKey.ReplicateKeyConfig != nil { + mergedKey.ReplicateKeyConfig = oldRawKey.ReplicateKeyConfig + } + + if updateKey.OllamaKeyConfig != nil && oldRedactedKey.OllamaKeyConfig != nil && oldRawKey.OllamaKeyConfig != nil { + if updateKey.OllamaKeyConfig.URL.IsRedacted() && + updateKey.OllamaKeyConfig.URL.Equals(&oldRedactedKey.OllamaKeyConfig.URL) { + mergedKey.OllamaKeyConfig.URL = oldRawKey.OllamaKeyConfig.URL + } + } + + if updateKey.SGLKeyConfig != nil && oldRedactedKey.SGLKeyConfig != nil && oldRawKey.SGLKeyConfig != nil { + if updateKey.SGLKeyConfig.URL.IsRedacted() && + updateKey.SGLKeyConfig.URL.Equals(&oldRedactedKey.SGLKeyConfig.URL) { + mergedKey.SGLKeyConfig.URL = oldRawKey.SGLKeyConfig.URL + } + } + + mergedKey.ConfigHash = oldRawKey.ConfigHash + mergedKey.Status = oldRawKey.Status + + return mergedKey +} + +func getKeyIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) { + keyValue := ctx.UserValue("key_id") + if keyValue == nil { + return "", fmt.Errorf("missing key_id parameter") + } + + keyID, ok := keyValue.(string) + if !ok || keyID == "" { + return "", fmt.Errorf("invalid key_id parameter") + } + + decoded, err := url.PathUnescape(keyID) + if err != nil { + return "", fmt.Errorf("invalid key_id parameter encoding: %v", err) + } + + return decoded, nil +} + +// validateProviderKeyURL checks that Ollama/SGL keys have a server URL configured. +func validateProviderKeyURL(provider schemas.ModelProvider, key schemas.Key) error { + switch provider { + case schemas.Ollama: + if key.OllamaKeyConfig == nil || !key.OllamaKeyConfig.URL.IsDefined() { + return fmt.Errorf("ollama_key_config.url is required for Ollama keys") + } + case schemas.SGL: + if key.SGLKeyConfig == nil || !key.SGLKeyConfig.URL.IsDefined() { + return fmt.Errorf("sgl_key_config.url is required for SGL keys") + } + } + return nil +} diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 625c0a9e15..51729dce1c 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -4,11 +4,9 @@ package handlers import ( "context" - "encoding/json" "errors" "fmt" "net/url" - "regexp" "slices" "sort" "strings" @@ -61,21 +59,19 @@ const ( // ProviderResponse represents the response for provider operations type ProviderResponse struct { - Name schemas.ModelProvider `json:"name"` - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings - ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Include raw request in BifrostResponse - SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration - PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides - ProviderStatus ProviderStatus `json:"provider_status"` // Health/initialization status of the provider - Status string `json:"status,omitempty"` // Operational status (e.g., list_models_failed) - Description string `json:"description,omitempty"` // Error/status description - ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection + Name schemas.ModelProvider `json:"name"` + NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Include raw request in BifrostResponse + SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration + OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration + ProviderStatus ProviderStatus `json:"provider_status"` // Health/initialization status of the provider + Status string `json:"status,omitempty"` // Operational status (e.g., list_models_failed) + Description string `json:"description,omitempty"` // Error/status description + ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection } // ListProvidersResponse represents the response for listing all providers @@ -90,17 +86,44 @@ type ErrorResponse struct { Message string `json:"message,omitempty"` } +type providerCreatePayload struct { + Provider schemas.ModelProvider `json:"provider"` + NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` + SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` + StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration +} + +type providerUpdatePayload struct { + NetworkConfig schemas.NetworkConfig `json:"network_config"` + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` + SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` + StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration +} + // RegisterRoutes registers all provider management routes func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Provider CRUD operations r.GET("/api/providers", lib.ChainMiddlewares(h.listProviders, middlewares...)) r.GET("/api/providers/{provider}", lib.ChainMiddlewares(h.getProvider, middlewares...)) + r.GET("/api/providers/{provider}/keys", lib.ChainMiddlewares(h.listProviderKeys, middlewares...)) + r.GET("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.getProviderKey, middlewares...)) r.POST("/api/providers", lib.ChainMiddlewares(h.addProvider, middlewares...)) + r.POST("/api/providers/{provider}/keys", lib.ChainMiddlewares(h.createProviderKey, middlewares...)) r.PUT("/api/providers/{provider}", lib.ChainMiddlewares(h.updateProvider, middlewares...)) + r.PUT("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.updateProviderKey, middlewares...)) r.DELETE("/api/providers/{provider}", lib.ChainMiddlewares(h.deleteProvider, middlewares...)) + r.DELETE("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.deleteProviderKey, middlewares...)) r.GET("/api/keys", lib.ChainMiddlewares(h.listKeys, middlewares...)) r.GET("/api/models", lib.ChainMiddlewares(h.listModels, middlewares...)) - r.GET("/api/models/details", lib.ChainMiddlewares(h.listModelDetails, middlewares...)) r.GET("/api/models/parameters", lib.ChainMiddlewares(h.getModelParameters, middlewares...)) r.GET("/api/models/base", lib.ChainMiddlewares(h.listBaseModels, middlewares...)) } @@ -200,21 +223,8 @@ func (h *ProviderHandler) getProvider(ctx *fasthttp.RequestCtx) { // addProvider handles POST /api/providers - Add a new provider // NOTE: This only gets called when a new custom provider is added func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { - // Payload structure - var payload = struct { - Provider schemas.ModelProvider `json:"provider"` - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings - ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse - SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse - StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration - PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides - }{} - if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil { + var payload providerCreatePayload + if err := sonic.Unmarshal(ctx.PostBody(), &payload); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) return } @@ -253,10 +263,6 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { return } } - if err := validatePricingOverrides(payload.PricingOverrides); err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid pricing overrides: %v", err)) - return - } // Validate retry backoff values if NetworkConfig is provided if payload.NetworkConfig != nil { if err := validateRetryBackoff(payload.NetworkConfig); err != nil { @@ -277,7 +283,6 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { // Construct ProviderConfig from individual fields config := configstore.ProviderConfig{ - Keys: payload.Keys, NetworkConfig: payload.NetworkConfig, ProxyConfig: payload.ProxyConfig, ConcurrencyAndBufferSize: payload.ConcurrencyAndBufferSize, @@ -286,7 +291,6 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { StoreRawRequestResponse: payload.StoreRawRequestResponse != nil && *payload.StoreRawRequestResponse, CustomProviderConfig: payload.CustomProviderConfig, OpenAIConfig: payload.OpenAIConfig, - PricingOverrides: payload.PricingOverrides, } // Validate custom provider configuration before persisting if err := lib.ValidateCustomProvider(config, payload.Provider); err != nil { @@ -303,17 +307,10 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to add provider: %v", err)) return } - if h.inMemoryStore.ModelCatalog != nil { - if err := h.inMemoryStore.ModelCatalog.SetProviderPricingOverrides(payload.Provider, config.PricingOverrides); err != nil { - logger.Warn("Failed to set pricing overrides for provider %s: %v", payload.Provider, err) - } - } logger.Info("Provider %s added successfully", payload.Provider) // Attempt model discovery - err := h.attemptModelDiscovery(ctx, payload.Provider, payload.CustomProviderConfig) - - if err != nil { + if err := h.attemptModelDiscovery(ctx, payload.Provider, payload.CustomProviderConfig); err != nil { logger.Warn("Model discovery failed for provider %s: %v", payload.Provider, err) } @@ -330,7 +327,6 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { SendBackRawResponse: config.SendBackRawResponse, StoreRawRequestResponse: config.StoreRawRequestResponse, CustomProviderConfig: config.CustomProviderConfig, - PricingOverrides: config.PricingOverrides, Status: config.Status, Description: config.Description, }, ProviderStatusActive) @@ -357,26 +353,14 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { } var payload = struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings - ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse - SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse - StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration - PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides + Keys []schemas.Key `json:"keys"` // API keys for the provider + providerUpdatePayload }{} if err := sonic.Unmarshal(ctx.PostBody(), &payload); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) return } - if err := validatePricingOverrides(payload.PricingOverrides); err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("invalid pricing overrides: %v", err)) - return - } // Get the raw config to access actual values for merging with redacted request values oldConfigRaw, err := h.inMemoryStore.GetProviderConfigRaw(provider) @@ -392,20 +376,7 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { oldConfigRaw = &configstore.ProviderConfig{} } - oldConfigRedacted, err := h.inMemoryStore.GetProviderConfigRedacted(provider) - if err != nil { - if !errors.Is(err, lib.ErrNotFound) { - logger.Warn("Failed to get old redacted config for provider %s: %v", provider, err) - SendError(ctx, fasthttp.StatusInternalServerError, err.Error()) - return - } - } - - if oldConfigRedacted == nil { - oldConfigRedacted = &configstore.ProviderConfig{} - } - - // Construct ProviderConfig from individual fields + // Construct ProviderConfig from individual fields (keys are managed separately via /keys endpoints) config := configstore.ProviderConfig{ Keys: oldConfigRaw.Keys, NetworkConfig: oldConfigRaw.NetworkConfig, @@ -413,45 +384,11 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { ProxyConfig: oldConfigRaw.ProxyConfig, CustomProviderConfig: oldConfigRaw.CustomProviderConfig, OpenAIConfig: oldConfigRaw.OpenAIConfig, - PricingOverrides: oldConfigRaw.PricingOverrides, StoreRawRequestResponse: oldConfigRaw.StoreRawRequestResponse, Status: oldConfigRaw.Status, Description: oldConfigRaw.Description, } - // Environment variable cleanup is now handled automatically by mergeKeys function - - var keysToAdd []schemas.Key - var keysToUpdate []schemas.Key - - for _, key := range payload.Keys { - if !slices.ContainsFunc(oldConfigRaw.Keys, func(k schemas.Key) bool { - return k.ID == key.ID - }) { - // By default new keys are enabled - key.Enabled = bifrost.Ptr(true) - keysToAdd = append(keysToAdd, key) - } else { - keysToUpdate = append(keysToUpdate, key) - } - } - - var keysToDelete []schemas.Key - for _, key := range oldConfigRaw.Keys { - if !slices.ContainsFunc(payload.Keys, func(k schemas.Key) bool { - return k.ID == key.ID - }) { - keysToDelete = append(keysToDelete, key) - } - } - - keys, err := h.mergeKeys(oldConfigRaw.Keys, oldConfigRedacted.Keys, keysToAdd, keysToDelete, keysToUpdate) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid keys: %v", err)) - return - } - config.Keys = keys - if payload.ConcurrencyAndBufferSize.Concurrency == 0 { SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be greater than 0") return @@ -501,7 +438,6 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { config.ProxyConfig = payload.ProxyConfig config.CustomProviderConfig = payload.CustomProviderConfig config.OpenAIConfig = payload.OpenAIConfig - config.PricingOverrides = payload.PricingOverrides if payload.SendBackRawRequest != nil { config.SendBackRawRequest = *payload.SendBackRawRequest } @@ -538,12 +474,6 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider: %v", err)) return } - if h.inMemoryStore.ModelCatalog != nil { - if err := h.inMemoryStore.ModelCatalog.SetProviderPricingOverrides(provider, config.PricingOverrides); err != nil { - logger.Warn("Failed to set pricing overrides for provider %s: %v", provider, err) - } - } - // Attempt model discovery err = h.attemptModelDiscovery(ctx, provider, payload.CustomProviderConfig) @@ -564,7 +494,6 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { SendBackRawResponse: config.SendBackRawResponse, StoreRawRequestResponse: config.StoreRawRequestResponse, CustomProviderConfig: config.CustomProviderConfig, - PricingOverrides: config.PricingOverrides, Status: config.Status, Description: config.Description, }, ProviderStatusActive) @@ -661,8 +590,8 @@ type listedModel struct { // Query parameters: // - query: Filter models by name (case-insensitive partial match) // - provider: Filter by specific provider name -// - keys: Comma-separated list of key IDs to filter models accessible by those keys -// - unfiltered: If true, bypass provider-level model pool restrictions only +// - keys: Comma-separated list of provider key UUIDs to filter models accessible by those keys +// - vks: Comma-separated list of virtual key UUIDs to filter models accessible by those virtual keys // - limit: Maximum number of results to return (default: 5) func (h *ProviderHandler) listModels(ctx *fasthttp.RequestCtx) { query := parseModelListQuery(ctx, 5) @@ -891,41 +820,28 @@ func (h *ProviderHandler) getModelParameters(ctx *fasthttp.RequestCtx) { } // keyAllowsModelForList reports whether a provider key permits model for catalog listing. -func keyAllowsModelForList(provider schemas.ModelProvider, model string, key schemas.Key, modelCatalog *modelcatalog.ModelCatalog) bool { - if len(key.BlacklistedModels) > 0 && keyModelListAllowsModel(provider, model, key.BlacklistedModels, modelCatalog) { +// When a non-nil catalog is provided, it also checks whether any allowlisted +// model resolves to the same base model name as the queried model (alias matching). +func keyAllowsModelForList(key schemas.Key, model string, catalog *modelcatalog.ModelCatalog) bool { + if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { return false } if len(key.Models) > 0 { - return keyModelListAllowsModel(provider, model, key.Models, modelCatalog) - } - return true -} - -// keyModelListAllowsModel reports whether model matches a key allow/deny list entry, -// using catalog-aware alias matching when model metadata is available. -func keyModelListAllowsModel(provider schemas.ModelProvider, model string, allowedModels []string, modelCatalog *modelcatalog.ModelCatalog) bool { - if len(allowedModels) == 0 { - return false - } - - if modelCatalog == nil { - return slices.Contains(allowedModels, model) - } - - if modelCatalog.IsModelAllowedForProvider(provider, model, allowedModels) { - return true - } - - for _, allowedModel := range allowedModels { - if strings.Contains(allowedModel, "/") { - continue - } - if modelCatalog.IsSameModel(allowedModel, model) { + if slices.Contains(key.Models, model) { return true } + // Catalog-aware alias matching: a key allowlisting "gpt-4o-2024-08-06" + // should also grant access to its base model "gpt-4o" in listings. + if catalog != nil { + for _, allowed := range key.Models { + if catalog.GetBaseModelName(allowed) == model { + return true + } + } + } + return false } - - return false + return true } // matchesModelQuery applies the shared query match used by /api/models, @@ -1010,7 +926,7 @@ func filterModelsByKeysWithAccessMap(config *configstore.ProviderConfig, provide for _, model := range models { grantedBy := make([]string, 0, len(matchedKeys)) for _, matched := range matchedKeys { - if keyAllowsModelForList(provider, model, matched.key, modelCatalog) { + if keyAllowsModelForList(matched.key, model, modelCatalog) { grantedBy = append(grantedBy, matched.id) } } @@ -1072,201 +988,6 @@ func (h *ProviderHandler) listBaseModels(ctx *fasthttp.RequestCtx) { SendJSON(ctx, ListBaseModelsResponse{Models: baseModels, Total: total}) } -// mergeKeys merges new keys with old, preserving values that are redacted in the new config -func (h *ProviderHandler) mergeKeys(oldRawKeys []schemas.Key, oldRedactedKeys []schemas.Key, keysToAdd []schemas.Key, keysToDelete []schemas.Key, keysToUpdate []schemas.Key) ([]schemas.Key, error) { - // Create a map of indices to delete - toDelete := make(map[int]bool) - for _, key := range keysToDelete { - for i, oldKey := range oldRawKeys { - if oldKey.ID == key.ID { - toDelete[i] = true - break - } - } - } - - // Create a map of updates by ID for quick lookup - updates := make(map[string]schemas.Key) - for _, key := range keysToUpdate { - updates[key.ID] = key - } - - // Map old redacted keys by ID for reliable lookup - redactedByID := make(map[string]schemas.Key) - for _, rk := range oldRedactedKeys { - redactedByID[rk.ID] = rk - } - - // Process existing keys (handle updates and deletions) - var resultKeys []schemas.Key - for i, oldRawKey := range oldRawKeys { - // Skip if this key should be deleted - if toDelete[i] { - continue - } - // Check if this key should be updated - if updateKey, exists := updates[oldRawKey.ID]; exists { - oldRedactedKey, ok := redactedByID[oldRawKey.ID] - if !ok { - oldRedactedKey = schemas.Key{} - } - mergedKey := updateKey - - // Handle redacted values - preserve old value if new value is redacted/env var AND it's the same as old redacted value - if updateKey.Value.IsRedacted() && - updateKey.Value.Equals(&oldRedactedKey.Value) { - mergedKey.Value = oldRawKey.Value - } - - // Handle Azure config redacted values - if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.Endpoint.IsRedacted() && - updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) { - mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint - } - if updateKey.AzureKeyConfig.APIVersion != nil && - oldRedactedKey.AzureKeyConfig.APIVersion != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.APIVersion.IsRedacted() && - updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) { - mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion - } - } - // handle client id and secret and tenant id - if updateKey.AzureKeyConfig.ClientID != nil && - oldRedactedKey.AzureKeyConfig.ClientID != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.ClientID.IsRedacted() && - updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) { - mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID - } - } - if updateKey.AzureKeyConfig.ClientSecret != nil && - oldRedactedKey.AzureKeyConfig.ClientSecret != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.ClientSecret.IsRedacted() && - updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) { - mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret - } - } - if updateKey.AzureKeyConfig.TenantID != nil && - oldRedactedKey.AzureKeyConfig.TenantID != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.TenantID.IsRedacted() && - updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) { - mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID - } - } - } - - // Handle Vertex config redacted values - if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil { - if updateKey.VertexKeyConfig.ProjectID.IsRedacted() && - updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) { - mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID - } - if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() && - updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) { - mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber - } - if updateKey.VertexKeyConfig.Region.IsRedacted() && - updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) { - mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region - } - if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() && - updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) { - mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials - } - } - - // Handle Bedrock config redacted values - if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() && - updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) { - mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey - } - if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() && - updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) { - mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey - } - if updateKey.BedrockKeyConfig.SessionToken != nil && - oldRedactedKey.BedrockKeyConfig.SessionToken != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.SessionToken.IsRedacted() && - updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) { - mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken - } - } - if updateKey.BedrockKeyConfig.Region != nil && - oldRedactedKey.BedrockKeyConfig.Region != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.Region.IsRedacted() && - updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) { - mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region - } - } - if updateKey.BedrockKeyConfig.ARN != nil && - oldRedactedKey.BedrockKeyConfig.ARN != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.ARN.IsRedacted() && - updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) { - mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN - } - } - if updateKey.BedrockKeyConfig.RoleARN != nil && - oldRedactedKey.BedrockKeyConfig.RoleARN != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.RoleARN.IsRedacted() && - updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) { - mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN - } - } - if updateKey.BedrockKeyConfig.ExternalID != nil && - oldRedactedKey.BedrockKeyConfig.ExternalID != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.ExternalID.IsRedacted() && - updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) { - mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID - } - } - if updateKey.BedrockKeyConfig.RoleSessionName != nil && - oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() && - updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) { - mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName - } - } - } - - // Handle VLLM config redacted values - if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil { - if updateKey.VLLMKeyConfig.URL.IsRedacted() && - updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) { - mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL - } - } - - // Preserve ConfigHash from old key (UI doesn't send it back) - mergedKey.ConfigHash = oldRawKey.ConfigHash - - // Preserve Status and Description from old key (UI doesn't send them back, they're updated by model discovery) - mergedKey.Status = oldRawKey.Status - mergedKey.Description = oldRawKey.Description - - resultKeys = append(resultKeys, mergedKey) - } else { - // Keep unchanged key - resultKeys = append(resultKeys, oldRawKey) - } - } - - // Add new keys - resultKeys = append(resultKeys, keysToAdd...) - - return resultKeys, nil -} - // attemptModelDiscovery performs model discovery with timeout func (h *ProviderHandler) attemptModelDiscovery(ctx *fasthttp.RequestCtx, provider schemas.ModelProvider, customProviderConfig *schemas.CustomProviderConfig) error { // Determine if we should attempt model discovery @@ -1300,7 +1021,6 @@ func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelPr return ProviderResponse{ Name: provider, - Keys: config.Keys, NetworkConfig: *config.NetworkConfig, ConcurrencyAndBufferSize: *config.ConcurrencyAndBufferSize, ProxyConfig: config.ProxyConfig, @@ -1309,7 +1029,6 @@ func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelPr StoreRawRequestResponse: config.StoreRawRequestResponse, CustomProviderConfig: config.CustomProviderConfig, OpenAIConfig: config.OpenAIConfig, - PricingOverrides: config.PricingOverrides, ProviderStatus: status, Status: config.Status, Description: config.Description, @@ -1317,101 +1036,6 @@ func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelPr } } -func validatePricingOverrides(overrides []schemas.ProviderPricingOverride) error { - for i, override := range overrides { - if strings.TrimSpace(override.ModelPattern) == "" { - return fmt.Errorf("override[%d]: model_pattern is required", i) - } - - switch override.MatchType { - case schemas.PricingOverrideMatchExact: - if strings.Contains(override.ModelPattern, "*") { - return fmt.Errorf("override[%d]: exact match_type cannot include '*'", i) - } - case schemas.PricingOverrideMatchWildcard: - if !strings.Contains(override.ModelPattern, "*") { - return fmt.Errorf("override[%d]: wildcard match_type requires '*' in model_pattern", i) - } - case schemas.PricingOverrideMatchRegex: - if _, err := regexp.Compile(override.ModelPattern); err != nil { - return fmt.Errorf("override[%d]: invalid regex pattern: %w", i, err) - } - default: - return fmt.Errorf("override[%d]: unsupported match_type %q", i, override.MatchType) - } - - for _, requestType := range override.RequestTypes { - if !isSupportedOverrideRequestType(requestType) { - return fmt.Errorf("override[%d]: unsupported request_type %q", i, requestType) - } - } - - if err := validatePricingOverrideNonNegativeFields(i, override); err != nil { - return err - } - } - - return nil -} - -func isSupportedOverrideRequestType(requestType schemas.RequestType) bool { - switch requestType { - case schemas.TextCompletionRequest, - schemas.TextCompletionStreamRequest, - schemas.ChatCompletionRequest, - schemas.ChatCompletionStreamRequest, - schemas.ResponsesRequest, - schemas.ResponsesStreamRequest, - schemas.EmbeddingRequest, - schemas.RerankRequest, - schemas.SpeechRequest, - schemas.SpeechStreamRequest, - schemas.TranscriptionRequest, - schemas.TranscriptionStreamRequest, - schemas.ImageGenerationRequest, - schemas.ImageGenerationStreamRequest: - return true - default: - return false - } -} - -func validatePricingOverrideNonNegativeFields(index int, override schemas.ProviderPricingOverride) error { - optionalValues := map[string]*float64{ - "input_cost_per_token": override.InputCostPerToken, - "output_cost_per_token": override.OutputCostPerToken, - "input_cost_per_video_per_second": override.InputCostPerVideoPerSecond, - "input_cost_per_audio_per_second": override.InputCostPerAudioPerSecond, - "input_cost_per_character": override.InputCostPerCharacter, - "input_cost_per_token_above_128k_tokens": override.InputCostPerTokenAbove128kTokens, - "input_cost_per_image_above_128k_tokens": override.InputCostPerImageAbove128kTokens, - "input_cost_per_video_per_second_above_128k_tokens": override.InputCostPerVideoPerSecondAbove128kTokens, - "input_cost_per_audio_per_second_above_128k_tokens": override.InputCostPerAudioPerSecondAbove128kTokens, - "output_cost_per_token_above_128k_tokens": override.OutputCostPerTokenAbove128kTokens, - "input_cost_per_token_above_200k_tokens": override.InputCostPerTokenAbove200kTokens, - "output_cost_per_token_above_200k_tokens": override.OutputCostPerTokenAbove200kTokens, - "cache_creation_input_token_cost_above_200k_tokens": override.CacheCreationInputTokenCostAbove200kTokens, - "cache_read_input_token_cost_above_200k_tokens": override.CacheReadInputTokenCostAbove200kTokens, - "cache_read_input_token_cost": override.CacheReadInputTokenCost, - "cache_creation_input_token_cost": override.CacheCreationInputTokenCost, - "input_cost_per_token_batches": override.InputCostPerTokenBatches, - "output_cost_per_token_batches": override.OutputCostPerTokenBatches, - "input_cost_per_image_token": override.InputCostPerImageToken, - "output_cost_per_image_token": override.OutputCostPerImageToken, - "input_cost_per_image": override.InputCostPerImage, - "output_cost_per_image": override.OutputCostPerImage, - "cache_read_input_image_token_cost": override.CacheReadInputImageTokenCost, - } - - for fieldName, value := range optionalValues { - if value != nil && *value < 0 { - return fmt.Errorf("override[%d]: %s must be non-negative", index, fieldName) - } - } - - return nil -} - func getProviderFromCtx(ctx *fasthttp.RequestCtx) (schemas.ModelProvider, error) { providerValue := ctx.UserValue("provider") if providerValue == nil { diff --git a/transports/bifrost-http/handlers/realtime_client_secrets.go b/transports/bifrost-http/handlers/realtime_client_secrets.go new file mode 100644 index 0000000000..9c761d0692 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_client_secrets.go @@ -0,0 +1,416 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "mime" + "strings" + "time" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// RealtimeClientSecretsHandler exposes OpenAI-compatible HTTP routes for +// minting short-lived Realtime client secrets. +type RealtimeClientSecretsHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + routeSpecs map[string]schemas.RealtimeSessionRoute +} + +func NewRealtimeClientSecretsHandler(client *bifrost.Bifrost, config *lib.Config) *RealtimeClientSecretsHandler { + return &RealtimeClientSecretsHandler{ + client: client, + config: config, + handlerStore: config, + routeSpecs: make(map[string]schemas.RealtimeSessionRoute), + } +} + +func (h *RealtimeClientSecretsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleRequest, middlewares...) + for _, route := range h.realtimeSessionRoutes() { + h.routeSpecs[route.Path] = route + r.POST(route.Path, handler) + } +} + +func (h *RealtimeClientSecretsHandler) findGovernancePlugin() governance.BaseGovernancePlugin { + basePlugins := h.config.BasePlugins.Load() + if basePlugins == nil { + return nil + } + + for _, plugin := range *basePlugins { + if governancePlugin, ok := plugin.(governance.BaseGovernancePlugin); ok { + return governancePlugin + } + } + + return nil +} + +func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) { + if !isJSONContentType(string(ctx.Request.Header.ContentType())) { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "Content-Type must be application/json", + nil, + )) + return + } + + body := append([]byte(nil), ctx.Request.Body()...) + route, ok := h.routeSpecs[string(ctx.Path())] + if !ok { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusNotFound, + "invalid_request_error", + "unsupported realtime client secret route", + nil, + )) + return + } + + providerKey, model, normalizedBody, err := resolveRealtimeClientSecretTarget(route, body) + if err != nil { + SendBifrostError(ctx, err) + return + } + + bifrostCtx, cancel := lib.ConvertToBifrostContext( + ctx, + h.handlerStore.ShouldAllowDirectKeys(), + h.config.GetHeaderMatcher(), + h.config.GetMCPHeaderCombinedAllowlist(), + ) + defer cancel() + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if route.DefaultProvider == schemas.OpenAI { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + if governanceUserID, ok := ctx.UserValue(schemas.BifrostContextKeyGovernanceUserID).(string); ok && governanceUserID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, governanceUserID) + } + if bifrostErr := h.evaluateMintingGovernance(bifrostCtx, providerKey, model); bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "provider not found: "+string(providerKey), + nil, + )) + return + } + + key, keyErr := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + if keyErr != nil { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + keyErr.Error(), + keyErr, + )) + return + } + + // Resolve model aliases now that the key is selected so the forwarded body + // carries the provider's canonical model, matching wsrealtime/webrtc flows. + if resolved := key.Aliases.Resolve(model); resolved != "" && resolved != model { + model = resolved + reparsed, parseErr := schemas.ParseRealtimeClientSecretBody(normalizedBody) + if parseErr != nil { + SendBifrostError(ctx, parseErr) + return + } + rewritten, normalizeErr := normalizeRealtimeClientSecretBody(reparsed, model) + if normalizeErr != nil { + SendBifrostError(ctx, normalizeErr) + return + } + normalizedBody = rewritten + } + + sessionProvider, ok := provider.(schemas.RealtimeSessionProvider) + if !ok { + SendBifrostError(ctx, realtimeSessionNotSupportedError(providerKey, provider)) + return + } + + resp, bifrostErr := sessionProvider.CreateRealtimeClientSecret(bifrostCtx, key, route.EndpointType, normalizedBody) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + cacheRealtimeEphemeralKeyMapping( + h.handlerStore.GetKVStore(), + resp.Body, + key.ID, + bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey), + ) + + writeRealtimeClientSecretResponse(ctx, resp) +} + +func (h *RealtimeClientSecretsHandler) evaluateMintingGovernance( + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, +) *schemas.BifrostError { + governancePlugin := h.findGovernancePlugin() + if governancePlugin == nil { + return nil + } + + _, bifrostErr := governancePlugin.EvaluateGovernanceRequest(bifrostCtx, &governance.EvaluationRequest{ + VirtualKey: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey), + Provider: providerKey, + Model: model, + UserID: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyGovernanceUserID), + }, schemas.RealtimeRequest) + return bifrostErr +} + +func (h *RealtimeClientSecretsHandler) realtimeSessionRoutes() []schemas.RealtimeSessionRoute { + routes := []schemas.RealtimeSessionRoute{ + { + Path: "/v1/realtime/client_secrets", + EndpointType: schemas.RealtimeSessionEndpointClientSecrets, + }, + { + Path: "/v1/realtime/sessions", + EndpointType: schemas.RealtimeSessionEndpointSessions, + }, + } + + for _, path := range integrations.OpenAIRealtimeClientSecretPaths("/openai") { + endpointType := schemas.RealtimeSessionEndpointClientSecrets + if strings.HasSuffix(path, "/realtime/sessions") { + endpointType = schemas.RealtimeSessionEndpointSessions + } + routes = append(routes, schemas.RealtimeSessionRoute{ + Path: path, + EndpointType: endpointType, + DefaultProvider: schemas.OpenAI, + }) + } + return routes +} + +func resolveRealtimeClientSecretTarget(route schemas.RealtimeSessionRoute, body []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + root, err := schemas.ParseRealtimeClientSecretBody(body) + if err != nil { + return "", "", nil, err + } + + rawModel, err := schemas.ExtractRealtimeClientSecretModel(root) + if err != nil { + return "", "", nil, err + } + + defaultProvider := route.DefaultProvider + providerKey, model := schemas.ParseModelString(rawModel, defaultProvider) + if defaultProvider == "" && providerKey == "" { + return "", "", nil, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "session.model must use provider/model on /v1 realtime client secret routes", + nil, + ) + } + if providerKey == "" || model == "" { + return "", "", nil, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "session.model is required", + nil, + ) + } + + // Normalize the forwarded body so the upstream provider sees the bare model + // (strip provider prefix). Mirrors resolveRealtimeSDPTarget normalization. + normalizedBody, normalizeErr := normalizeRealtimeClientSecretBody(root, model) + if normalizeErr != nil { + return "", "", nil, normalizeErr + } + + return providerKey, model, normalizedBody, nil +} + +func normalizeRealtimeClientSecretBody(root map[string]json.RawMessage, bareModel string) ([]byte, *schemas.BifrostError) { + normalizedModel, marshalErr := json.Marshal(bareModel) + if marshalErr != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + + // Normalize session.model if present + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 { + var session map[string]json.RawMessage + if err := json.Unmarshal(sessionJSON, &session); err == nil { + if _, hasModel := session["model"]; hasModel { + session["model"] = normalizedModel + rewritten, err := json.Marshal(session) + if err != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode session", err) + } + root["session"] = rewritten + } + } + } + // Normalize top-level model if present + if _, ok := root["model"]; ok { + root["model"] = normalizedModel + } + + normalized, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode body", marshalErr) + } + return normalized, nil +} + +const realtimeEphemeralKeyMappingPrefix = "realtime:ephemeral-key:" + +type realtimeEphemeralKeyMapping struct { + KeyID string `json:"key_id,omitempty"` + VirtualKey string `json:"virtual_key,omitempty"` +} + +func cacheRealtimeEphemeralKeyMapping(kv schemas.KVStore, body []byte, keyID string, virtualKey string) { + if kv == nil || len(body) == 0 || strings.TrimSpace(keyID) == "" { + return + } + + token, ttl, ok := parseRealtimeEphemeralKeyMapping(body) + if !ok || strings.TrimSpace(token) == "" || ttl <= 0 { + return + } + + payload, err := json.Marshal(realtimeEphemeralKeyMapping{ + KeyID: strings.TrimSpace(keyID), + VirtualKey: strings.TrimSpace(virtualKey), + }) + if err != nil { + logger.Warn("failed to encode realtime ephemeral key mapping for key_id=%s: %v", keyID, err) + return + } + + if err := kv.SetWithTTL(buildRealtimeEphemeralKeyMappingKey(token), payload, ttl); err != nil { + logger.Warn("failed to cache realtime ephemeral key mapping for key_id=%s: %v", keyID, err) + } +} + +func parseRealtimeEphemeralKeyMapping(body []byte) (string, time.Duration, bool) { + var root map[string]json.RawMessage + if err := json.Unmarshal(body, &root); err != nil { + return "", 0, false + } + + var clientSecret struct { + Value string `json:"value"` + ExpiresAt int64 `json:"expires_at"` + } + + // OpenAI client_secrets responses expose the ephemeral token at the top level. + // Keep accepting the nested shape too so the mapping logic stays compatible + // with any provider/session endpoint variants that wrap the secret object. + if err := json.Unmarshal(body, &clientSecret); err != nil || strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 { + clientSecretRaw, ok := root["client_secret"] + if !ok || len(clientSecretRaw) == 0 || string(clientSecretRaw) == "null" { + return "", 0, false + } + if err := json.Unmarshal(clientSecretRaw, &clientSecret); err != nil { + return "", 0, false + } + } + if strings.TrimSpace(clientSecret.Value) == "" || clientSecret.ExpiresAt <= 0 { + return "", 0, false + } + + ttl := time.Until(time.Unix(clientSecret.ExpiresAt, 0)) + if ttl <= 0 { + return "", 0, false + } + + return clientSecret.Value, ttl, true +} + +func buildRealtimeEphemeralKeyMappingKey(token string) string { + return realtimeEphemeralKeyMappingPrefix + strings.TrimSpace(token) +} + +func realtimeSessionNotSupportedError(providerKey schemas.ModelProvider, provider schemas.Provider) *schemas.BifrostError { + if rtProvider, ok := provider.(schemas.RealtimeProvider); ok && rtProvider.SupportsRealtimeAPI() { + return newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + fmt.Sprintf("provider %s supports realtime websocket connections but not realtime client secret creation", providerKey), + nil, + ) + } + + return newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + fmt.Sprintf("provider %s does not support realtime client secret creation", providerKey), + nil, + ) +} + +func newRealtimeClientSecretHandlerError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + }, + } +} + +func writeRealtimeClientSecretResponse(ctx *fasthttp.RequestCtx, resp *schemas.BifrostPassthroughResponse) { + if resp == nil { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusInternalServerError, + "server_error", + "provider returned an empty realtime client secret response", + nil, + )) + return + } + + for key, value := range resp.Headers { + ctx.Response.Header.Set(key, value) + } + if len(ctx.Response.Header.ContentType()) == 0 { + ctx.SetContentType("application/json") + } + ctx.SetStatusCode(resp.StatusCode) + ctx.SetBody(resp.Body) +} + +func isJSONContentType(contentType string) bool { + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return false + } + mediaType = strings.ToLower(mediaType) + return mediaType == "application/json" || strings.HasSuffix(mediaType, "+json") +} diff --git a/transports/bifrost-http/handlers/realtime_client_secrets_test.go b/transports/bifrost-http/handlers/realtime_client_secrets_test.go new file mode 100644 index 0000000000..4a23782406 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_client_secrets_test.go @@ -0,0 +1,414 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +func TestResolveRealtimeClientSecretTarget(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + route schemas.RealtimeSessionRoute + body []byte + wantProvider schemas.ModelProvider + wantModel string + wantErr bool + }{ + { + name: "base route with session model", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets}, + body: []byte(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`), + wantProvider: schemas.OpenAI, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "base route with top level model", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions}, + body: []byte(`{"model":"openai/gpt-4o-realtime-preview"}`), + wantProvider: schemas.OpenAI, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "openai alias uses bare model", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + wantProvider: schemas.OpenAI, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "base route rejects bare model", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets}, + body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + wantErr: true, + }, + { + name: "missing model", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: []byte(`{"session":{}}`), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotProvider, gotModel, _, err := resolveRealtimeClientSecretTarget(tt.route, tt.body) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", err) + } + if gotProvider != tt.wantProvider { + t.Fatalf("provider = %q, want %q", gotProvider, tt.wantProvider) + } + if gotModel != tt.wantModel { + t.Fatalf("model = %q, want %q", gotModel, tt.wantModel) + } + }) + } +} + +func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + route schemas.RealtimeSessionRoute + body string + wantModel string // bare model expected in normalized body + }{ + { + name: "session.model provider prefix stripped", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets}, + body: `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "top-level model provider prefix stripped", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions}, + body: `{"model":"openai/gpt-4o-realtime-preview"}`, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "bare model unchanged on alias route", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: `{"session":{"model":"gpt-4o-realtime-preview"}}`, + wantModel: "gpt-4o-realtime-preview", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, normalizedBody, err := resolveRealtimeClientSecretTarget(tt.route, []byte(tt.body)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var root map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(normalizedBody, &root); unmarshalErr != nil { + t.Fatalf("failed to unmarshal normalized body: %v", unmarshalErr) + } + + // Check session.model if present + if sessionJSON, ok := root["session"]; ok { + var session map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(sessionJSON, &session); unmarshalErr != nil { + t.Fatalf("failed to unmarshal session: %v", unmarshalErr) + } + if modelJSON, ok := session["model"]; ok { + var model string + if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil { + t.Fatalf("failed to unmarshal session.model: %v", unmarshalErr) + } + if model != tt.wantModel { + t.Fatalf("session.model = %q, want %q", model, tt.wantModel) + } + } + } + + // Check top-level model if present + if modelJSON, ok := root["model"]; ok { + var model string + if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil { + t.Fatalf("failed to unmarshal model: %v", unmarshalErr) + } + if model != tt.wantModel { + t.Fatalf("model = %q, want %q", model, tt.wantModel) + } + } + }) + } +} + +func TestParseRealtimeEphemeralKeyMapping(t *testing.T) { + t.Parallel() + + token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{ + "value": "ek_test_123", + "expires_at": 4102444800 + }`)) + if !ok { + t.Fatal("expected ephemeral mapping to be parsed") + } + if token != "ek_test_123" { + t.Fatalf("token = %q, want %q", token, "ek_test_123") + } + if ttl <= 0 { + t.Fatalf("ttl = %v, want > 0", ttl) + } +} + +func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) { + t.Parallel() + + token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{ + "client_secret": { + "value": "ek_test_nested", + "expires_at": 4102444800 + } + }`)) + if !ok { + t.Fatal("expected nested ephemeral mapping to be parsed") + } + if token != "ek_test_nested" { + t.Fatalf("token = %q, want %q", token, "ek_test_nested") + } + if ttl <= 0 { + t.Fatalf("ttl = %v, want > 0", ttl) + } +} + +func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + body := []byte(`{ + "value": "ek_test_456", + "expires_at": ` + "4102444800" + ` + }`) + cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test") + + raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456")) + if err != nil { + t.Fatalf("store.Get() error = %v", err) + } + value, ok := raw.([]byte) + if !ok { + t.Fatalf("cached value type = %T, want []byte", raw) + } + var mapping realtimeEphemeralKeyMapping + if err := json.Unmarshal(value, &mapping); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if mapping.KeyID != "key_123" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123") + } + if mapping.VirtualKey != "sk-bf-test" { + t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test") + } +} + +func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + expired := time.Now().Add(-time.Minute).Unix() + body := fmt.Appendf(nil, `{ + "value": "ek_expired", + "expires_at": %d + }`, expired) + cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "") + + if _, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired")); err == nil { + t.Fatal("expected no cached mapping for expired token") + } +} + +func TestIsJSONContentType(t *testing.T) { + t.Parallel() + + if !isJSONContentType("application/json; charset=utf-8") { + t.Fatal("expected application/json content type to pass") + } + if !isJSONContentType("application/vnd.openai+json") { + t.Fatal("expected +json content type to pass") + } + if isJSONContentType("text/plain") { + t.Fatal("expected text/plain content type to fail") + } +} + +type mockRealtimeMintingGovernancePlugin struct { + err *schemas.BifrostError + seenUserID string + seenVirtualKey string + seenProvider schemas.ModelProvider + seenModel string + evaluateCalls int +} + +func (m *mockRealtimeMintingGovernancePlugin) GetName() string { + return governance.PluginName +} + +func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) { + m.evaluateCalls++ + m.seenUserID = "" + m.seenVirtualKey = "" + m.seenProvider = "" + m.seenModel = "" + if evaluationRequest != nil { + m.seenUserID = evaluationRequest.UserID + m.seenVirtualKey = evaluationRequest.VirtualKey + m.seenProvider = evaluationRequest.Provider + m.seenModel = evaluationRequest.Model + } + if ctx != nil && m.seenVirtualKey == "" { + m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) + } + if m.err != nil { + return nil, m.err + } + return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error { + return nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + return req, nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return result, bifrostErr, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) { + return req, nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error { + return nil +} + +func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore { + return nil +} + +func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) { + t.Parallel() + + config := &lib.Config{} + plugin := &mockRealtimeMintingGovernancePlugin{ + err: &schemas.BifrostError{ + Type: schemas.Ptr("virtual_key_required"), + StatusCode: schemas.Ptr(401), + Error: &schemas.ErrorField{ + Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.", + }, + }, + } + plugins := []schemas.BasePlugin{plugin} + config.BasePlugins.Store(&plugins) + + handler := NewRealtimeClientSecretsHandler(nil, config) + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + defer bifrostCtx.Done() + + err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime") + if err == nil { + t.Fatal("expected governance error") + } + if err.StatusCode == nil { + t.Fatal("expected status code") + } + if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want { + t.Fatalf("status = %d, want %d", got, want) + } +} + +func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) { + t.Parallel() + + config := &lib.Config{} + plugin := &mockRealtimeMintingGovernancePlugin{} + plugins := []schemas.BasePlugin{ + plugin, + } + config.BasePlugins.Store(&plugins) + + handler := NewRealtimeClientSecretsHandler(nil, config) + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + defer bifrostCtx.Done() + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, "user_123") + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123") + + if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil { + t.Fatalf("unexpected governance error: %v", err) + } + if plugin.evaluateCalls != 1 { + t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls) + } + if plugin.seenUserID != "user_123" { + t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123") + } + if plugin.seenVirtualKey != "sk-bf-123" { + t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123") + } + if plugin.seenProvider != schemas.OpenAI { + t.Fatalf("provider = %q, want %q", plugin.seenProvider, schemas.OpenAI) + } + if plugin.seenModel != "gpt-realtime" { + t.Fatalf("model = %q, want %q", plugin.seenModel, "gpt-realtime") + } +} + +func TestRealtimeClientSecretsEvaluateMintingGovernance_ContinuesWithoutGovernance(t *testing.T) { + t.Parallel() + + handler := NewRealtimeClientSecretsHandler(nil, &lib.Config{}) + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + defer bifrostCtx.Done() + + if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil { + t.Fatalf("unexpected governance error without plugin: %v", err) + } +} diff --git a/transports/bifrost-http/handlers/realtime_logging.go b/transports/bifrost-http/handlers/realtime_logging.go new file mode 100644 index 0000000000..3b05b1633d --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_logging.go @@ -0,0 +1,441 @@ +package handlers + +import ( + "encoding/json" + "strings" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" +) + +type realtimeTurnSource string + +const ( + realtimeTurnSourceEI realtimeTurnSource = "ei" + realtimeTurnSourceLM realtimeTurnSource = "lm" +) + +const ( + realtimeMissingTranscriptText = "[Audio transcription unavailable]" +) + +func extractRealtimeTurnSummary(event *schemas.BifrostRealtimeEvent, contentOverride string) string { + if strings.TrimSpace(contentOverride) != "" { + return strings.TrimSpace(contentOverride) + } + if event == nil { + return "" + } + if event.Error != nil && strings.TrimSpace(event.Error.Message) != "" { + return strings.TrimSpace(event.Error.Message) + } + if event.Delta != nil { + if text := strings.TrimSpace(event.Delta.Text); text != "" { + return text + } + if transcript := strings.TrimSpace(event.Delta.Transcript); transcript != "" { + return transcript + } + } + if event.Item != nil { + if summary := extractRealtimeItemSummary(event.Item); summary != "" { + return summary + } + } + if event.Session != nil && strings.TrimSpace(event.Session.Instructions) != "" { + return strings.TrimSpace(event.Session.Instructions) + } + if len(event.RawData) > 0 { + return strings.TrimSpace(string(event.RawData)) + } + return "" +} + +func extractRealtimeItemSummary(item *schemas.RealtimeItem) string { + if item == nil { + return "" + } + if summary := extractRealtimeContentSummary(item.Content); summary != "" { + return summary + } + switch { + case strings.TrimSpace(item.Output) != "": + return strings.TrimSpace(item.Output) + case strings.TrimSpace(item.Arguments) != "": + return strings.TrimSpace(item.Arguments) + case strings.TrimSpace(item.Name) != "": + return strings.TrimSpace(item.Name) + default: + return "" + } +} + +func extractRealtimeContentSummary(raw []byte) string { + if len(raw) == 0 { + return "" + } + + var decoded any + if err := sonic.Unmarshal(raw, &decoded); err != nil { + return strings.TrimSpace(string(raw)) + } + + var parts []string + collectRealtimeTextFragments(decoded, &parts) + return strings.Join(parts, " ") +} + +func collectRealtimeTextFragments(value any, parts *[]string) { + switch v := value.(type) { + case map[string]any: + for key, field := range v { + switch key { + case "text", "transcript", "input_text", "output_text", "output", "arguments": + if text, ok := field.(string); ok { + text = strings.TrimSpace(text) + if text != "" { + *parts = append(*parts, text) + } + continue + } + } + collectRealtimeTextFragments(field, parts) + } + case []any: + for _, item := range v { + collectRealtimeTextFragments(item, parts) + } + } +} + +func finalizedRealtimeInputSummary(event *schemas.BifrostRealtimeEvent) string { + if event == nil { + return "" + } + + switch event.Type { + case schemas.RTEventInputAudioTransCompleted: + if transcript := extractRealtimeExtraParamString(event, "transcript"); transcript != "" { + return transcript + } + return realtimeMissingTranscriptText + default: + if event != nil && event.Type == schemas.RTEventConversationItemDone && schemas.IsRealtimeUserInputEvent(event) { + if summary := extractRealtimeItemSummary(event.Item); summary != "" { + return summary + } + if realtimeItemHasMissingAudioTranscript(event.Item) { + return realtimeMissingTranscriptText + } + } + if schemas.IsRealtimeUserInputEvent(event) { + return extractRealtimeItemSummary(event.Item) + } + } + + return "" +} + +func pendingRealtimeInputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) { + if event == nil { + return "", "" + } + + switch event.Type { + case schemas.RTEventConversationItemRetrieved: + return "", "" + case schemas.RTEventInputAudioTransCompleted: + return realtimeEventItemID(event), finalizedRealtimeInputSummary(event) + default: + if schemas.IsRealtimeUserInputEvent(event) { + return realtimeEventItemID(event), finalizedRealtimeInputSummary(event) + } + } + + return "", "" +} + +func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool { + if item == nil || len(item.Content) == 0 { + return false + } + + var decoded []map[string]any + if err := sonic.Unmarshal(item.Content, &decoded); err != nil { + return false + } + + for _, part := range decoded { + partType, _ := part["type"].(string) + if partType != "input_audio" { + continue + } + transcript, exists := part["transcript"] + if !exists || transcript == nil { + return true + } + if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" { + return true + } + } + + return false +} + +func finalizedRealtimeToolOutputSummary(event *schemas.BifrostRealtimeEvent) string { + if !schemas.IsRealtimeToolOutputEvent(event) { + return "" + } + return extractRealtimeItemSummary(event.Item) +} + +func pendingRealtimeToolOutputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) { + if event == nil || event.Type == schemas.RTEventConversationItemRetrieved || !schemas.IsRealtimeToolOutputEvent(event) { + return "", "" + } + return realtimeEventItemID(event), finalizedRealtimeToolOutputSummary(event) +} + +func extractRealtimeExtraParamString(event *schemas.BifrostRealtimeEvent, key string) string { + if event == nil || event.ExtraParams == nil { + return "" + } + raw, ok := event.ExtraParams[key] + if !ok || len(raw) == 0 { + return "" + } + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return "" + } + return strings.TrimSpace(value) +} + +func realtimeEventItemID(event *schemas.BifrostRealtimeEvent) string { + if event == nil { + return "" + } + if event.Item != nil && strings.TrimSpace(event.Item.ID) != "" { + return strings.TrimSpace(event.Item.ID) + } + if event.Delta != nil && strings.TrimSpace(event.Delta.ItemID) != "" { + return strings.TrimSpace(event.Delta.ItemID) + } + return extractRealtimeExtraParamString(event, "item_id") +} + +func combineRealtimeInputRaw(turnInputs []bfws.RealtimeTurnInput) string { + var parts []string + for _, turnInput := range turnInputs { + if trimmed := strings.TrimSpace(turnInput.Raw); trimmed != "" { + parts = append(parts, trimmed) + } + } + return strings.Join(parts, "\n\n") +} + +type realtimeResponseDoneEnvelope struct { + Response struct { + Output []realtimeResponseDoneOutput `json:"output"` + Usage *realtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type realtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []realtimeResponseDoneContent `json:"content"` +} + +type realtimeResponseDoneContent struct { + Type string `json:"type"` + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type realtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails *realtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *realtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type realtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type realtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractRealtimeTurnUsage(provider schemas.RealtimeProvider, rawMessage []byte) *schemas.BifrostLLMUsage { + if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok { + if usage := extractor.ExtractRealtimeTurnUsage(rawMessage); usage != nil { + return usage + } + } + return extractRealtimeResponseDoneUsage(rawMessage) +} + +func extractRealtimeTurnOutputMessage(provider schemas.RealtimeProvider, rawMessage []byte, contentSummary string) *schemas.ChatMessage { + if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok { + if message := extractor.ExtractRealtimeTurnOutput(rawMessage); message != nil { + if strings.TrimSpace(contentSummary) != "" && (message.Content == nil || message.Content.ContentStr == nil || strings.TrimSpace(*message.Content.ContentStr) == "") { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentSummary))} + } + return message + } + } + return buildRealtimeAssistantLogMessage(rawMessage, contentSummary) +} + +func buildRealtimeAssistantLogMessage(rawMessage []byte, contentSummary string) *schemas.ChatMessage { + contentSummary = strings.TrimSpace(contentSummary) + var parsed realtimeResponseDoneEnvelope + if len(rawMessage) > 0 && sonic.Unmarshal(rawMessage, &parsed) == nil { + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if contentSummary == "" { + contentSummary = extractRealtimeResponseDoneAssistantText(parsed.Response.Output) + } + if contentSummary != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)} + } + + toolCalls := extractRealtimeResponseDoneToolCalls(parsed.Response.Output) + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + + if message.Content != nil || message.ChatAssistantMessage != nil { + return message + } + } + + if contentSummary == "" { + return nil + } + + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)}, + } +} + +func extractRealtimeResponseDoneAssistantText(outputs []realtimeResponseDoneOutput) string { + var parts []string + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + parts = append(parts, strings.TrimSpace(block.Text)) + case strings.TrimSpace(block.Transcript) != "": + parts = append(parts, strings.TrimSpace(block.Transcript)) + case strings.TrimSpace(block.Refusal) != "": + parts = append(parts, strings.TrimSpace(block.Refusal)) + } + } + } + return strings.Join(parts, " ") +} + +func extractRealtimeResponseDoneToolCalls(outputs []realtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func extractRealtimeResponseDoneUsage(rawMessage []byte) *schemas.BifrostLLMUsage { + if len(rawMessage) == 0 { + return nil + } + + var parsed realtimeResponseDoneEnvelope + if err := sonic.Unmarshal(rawMessage, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + totalTokens := parsed.Response.Usage.TotalTokens + if totalTokens == 0 && (parsed.Response.Usage.InputTokens > 0 || parsed.Response.Usage.OutputTokens > 0) { + totalTokens = parsed.Response.Usage.InputTokens + parsed.Response.Usage.OutputTokens + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: totalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} diff --git a/transports/bifrost-http/handlers/realtime_logging_test.go b/transports/bifrost-http/handlers/realtime_logging_test.go new file mode 100644 index 0000000000..054f2ea0e9 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_logging_test.go @@ -0,0 +1,435 @@ +package handlers + +import ( + "encoding/json" + "testing" + "time" + + "github.com/maximhq/bifrost/core/providers/openai" + "github.com/maximhq/bifrost/core/schemas" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" +) + +func TestShouldAccumulateRealtimeOutput(t *testing.T) { + provider := &openai.OpenAIProvider{} + if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseTextDelta) { + t.Fatal("expected response.text.delta to accumulate output text") + } + if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseAudioTransDelta) { + t.Fatal("expected response.audio_transcript.delta to accumulate output transcript") + } + if provider.ShouldAccumulateRealtimeOutput(schemas.RTEventInputAudioTransDelta) { + t.Fatal("did not expect input audio transcription delta to accumulate assistant output") + } +} + +func TestExtractRealtimeTurnSummary(t *testing.T) { + event := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Content: []byte(`[{"type":"input_text","text":"hello from realtime"}]`), + }, + } + + got := extractRealtimeTurnSummary(event, "") + if got != "hello from realtime" { + t.Fatalf("extractRealtimeTurnSummary() = %q, want %q", got, "hello from realtime") + } +} + +func TestFinalizedRealtimeInputSummary(t *testing.T) { + userCreate := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from browser"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userCreate); got != "hello from browser" { + t.Fatalf("finalizedRealtimeInputSummary(user create) = %q, want %q", got, "hello from browser") + } + + userRetrieved := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from retrieved item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userRetrieved); got != "hello from retrieved item" { + t.Fatalf("finalizedRealtimeInputSummary(user retrieved) = %q, want %q", got, "hello from retrieved item") + } + + userCreated := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from provider created item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userCreated); got != "hello from provider created item" { + t.Fatalf("finalizedRealtimeInputSummary(user created) = %q, want %q", got, "hello from provider created item") + } + + userAdded := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemAdded, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from provider added item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userAdded); got != "hello from provider added item" { + t.Fatalf("finalizedRealtimeInputSummary(user added) = %q, want %q", got, "hello from provider added item") + } + + userCreatedWithoutTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Role: "user", + Type: "message", + Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`), + }, + RawData: []byte(`{"type":"conversation.item.created","item":{"type":"message","role":"user","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`), + } + if got := finalizedRealtimeInputSummary(userCreatedWithoutTranscript); got != "" { + t.Fatalf("finalizedRealtimeInputSummary(user created without transcript) = %q, want empty", got) + } + + userDoneWithoutTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemDone, + Item: &schemas.RealtimeItem{ + Role: "user", + Type: "message", + Status: "completed", + Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`), + }, + RawData: []byte(`{"type":"conversation.item.done","item":{"type":"message","role":"user","status":"completed","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`), + } + if got := finalizedRealtimeInputSummary(userDoneWithoutTranscript); got != realtimeMissingTranscriptText { + t.Fatalf("finalizedRealtimeInputSummary(user done without transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + inputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "transcript": json.RawMessage(`"spoken user turn"`), + }, + } + if got := finalizedRealtimeInputSummary(inputTranscript); got != "spoken user turn" { + t.Fatalf("finalizedRealtimeInputSummary(input transcript) = %q, want %q", got, "spoken user turn") + } + + emptyInputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "transcript": json.RawMessage(`""`), + }, + RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","transcript":"","usage":{"total_tokens":11}}`), + } + if got := finalizedRealtimeInputSummary(emptyInputTranscript); got != realtimeMissingTranscriptText { + t.Fatalf("finalizedRealtimeInputSummary(empty input transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + missingInputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","usage":{"total_tokens":11}}`), + } + if got := finalizedRealtimeInputSummary(missingInputTranscript); got != realtimeMissingTranscriptText { + t.Fatalf("finalizedRealtimeInputSummary(missing input transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + assistantCreate := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Role: "assistant", + Content: []byte(`[{"type":"text","text":"assistant text"}]`), + }, + } + if got := finalizedRealtimeInputSummary(assistantCreate); got != "" { + t.Fatalf("finalizedRealtimeInputSummary(assistant create) = %q, want empty", got) + } +} + +func TestFinalizedRealtimeToolOutputSummary(t *testing.T) { + event := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(event); got != `{"nextResponse":"tool result"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary() = %q, want %q", got, `{"nextResponse":"tool result"}`) + } + + retrieved := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from retrieved"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(retrieved); got != `{"nextResponse":"tool result from retrieved"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(retrieved) = %q, want %q", got, `{"nextResponse":"tool result from retrieved"}`) + } + + created := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from created"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(created); got != `{"nextResponse":"tool result from created"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(created) = %q, want %q", got, `{"nextResponse":"tool result from created"}`) + } + + added := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemAdded, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from added"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(added); got != `{"nextResponse":"tool result from added"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(added) = %q, want %q", got, `{"nextResponse":"tool result from added"}`) + } +} + +func TestPendingRealtimeInputUpdate(t *testing.T) { + t.Parallel() + + transcriptEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "transcript": json.RawMessage(`"Hello."`), + }, + } + itemID, summary := pendingRealtimeInputUpdate(transcriptEvent) + if itemID != "item_123" || summary != "Hello." { + t.Fatalf("pendingRealtimeInputUpdate(transcript) = (%q, %q), want (%q, %q)", itemID, summary, "item_123", "Hello.") + } + + retrievedEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + ID: "item_123", + Role: "user", + Content: []byte(`[{"type":"input_text","text":"historical hello"}]`), + }, + } + itemID, summary = pendingRealtimeInputUpdate(retrievedEvent) + if itemID != "" || summary != "" { + t.Fatalf("pendingRealtimeInputUpdate(retrieved) = (%q, %q), want empty", itemID, summary) + } +} + +func TestPendingRealtimeToolOutputUpdate(t *testing.T) { + t.Parallel() + + toolOutputEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemDone, + Item: &schemas.RealtimeItem{ + ID: "item_tool_123", + Type: "function_call_output", + Output: `{"nextResponse":"tool result"}`, + }, + } + itemID, summary := pendingRealtimeToolOutputUpdate(toolOutputEvent) + if itemID != "item_tool_123" || summary != `{"nextResponse":"tool result"}` { + t.Fatalf("pendingRealtimeToolOutputUpdate(done) = (%q, %q), want (%q, %q)", itemID, summary, "item_tool_123", `{"nextResponse":"tool result"}`) + } + + retrievedToolOutputEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + ID: "item_tool_123", + Type: "function_call_output", + Output: `{"nextResponse":"historical tool result"}`, + }, + } + itemID, summary = pendingRealtimeToolOutputUpdate(retrievedToolOutputEvent) + if itemID != "" || summary != "" { + t.Fatalf("pendingRealtimeToolOutputUpdate(retrieved) = (%q, %q), want empty", itemID, summary) + } +} + +func TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload(t *testing.T) { + rawRequest := `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}` + rawResponse := []byte(`{ + "type":"response.done", + "response":{ + "output":[ + { + "id":"item_message_123", + "type":"message", + "content":[ + { + "type":"audio", + "transcript":"assistant turn text" + } + ] + } + ], + "usage":{ + "total_tokens":26, + "input_tokens":17, + "output_tokens":9, + "input_token_details":{ + "text_tokens":12, + "audio_tokens":5, + "image_tokens":0, + "cached_tokens":4 + }, + "output_token_details":{ + "text_tokens":7, + "audio_tokens":2 + } + } + } + }`) + + resp := buildRealtimeTurnPostResponse(&openai.OpenAIProvider{}, schemas.OpenAI, "gpt-4o-realtime-preview-2025-06-03", rawRequest, rawResponse, "", 4321) + if resp == nil || resp.ResponsesResponse == nil { + t.Fatal("expected realtime post response to be built") + } + if resp.ResponsesResponse.ExtraFields.Latency != 4321 { + t.Fatalf("Latency = %d, want %d", resp.ResponsesResponse.ExtraFields.Latency, 4321) + } + if resp.ResponsesResponse.Usage == nil || resp.ResponsesResponse.Usage.InputTokens != 17 || resp.ResponsesResponse.Usage.OutputTokens != 9 || resp.ResponsesResponse.Usage.TotalTokens != 26 { + t.Fatalf("Usage = %+v, want input=17 output=9 total=26", resp.ResponsesResponse.Usage) + } + if len(resp.ResponsesResponse.Output) != 1 { + t.Fatalf("len(Output) = %d, want 1", len(resp.ResponsesResponse.Output)) + } + if resp.ResponsesResponse.Output[0].Content == nil || resp.ResponsesResponse.Output[0].Content.ContentStr == nil || *resp.ResponsesResponse.Output[0].Content.ContentStr != "assistant turn text" { + t.Fatalf("Output[0].Content = %+v, want assistant turn text", resp.ResponsesResponse.Output[0].Content) + } + if got, ok := resp.ResponsesResponse.ExtraFields.RawRequest.(string); !ok || got != rawRequest { + t.Fatalf("RawRequest = %#v, want %q", resp.ResponsesResponse.ExtraFields.RawRequest, rawRequest) + } + if got, ok := resp.ResponsesResponse.ExtraFields.RawResponse.(string); !ok || got == "" { + t.Fatalf("RawResponse = %#v, want raw response string", resp.ResponsesResponse.ExtraFields.RawResponse) + } +} + +func TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks(t *testing.T) { + t.Parallel() + + session := bfws.NewSession(nil) + session.SetProviderSessionID("sess_provider_123") + session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`) + session.AppendRealtimeOutputText("partial assistant output") + + var ( + capturedResp *schemas.BifrostResponse + capturedErr *schemas.BifrostError + cleanedUp bool + ) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + RequestID: "req_realtime_123", + StartedAt: time.Now().Add(-time.Second), + PreHookValues: map[any]any{ + schemas.BifrostContextKeyGovernanceVirtualKeyID: "vk_123", + }, + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + capturedResp = result + capturedErr = err + return result, nil + }, + Cleanup: func() { + cleanedUp = true + }, + }) + + rawResponse := []byte(`{"type":"error","error":{"type":"server_error","message":"Virtual key is required."}}`) + postErr := finalizeRealtimeTurnHooksWithError( + nil, + nil, + session, + schemas.OpenAI, + "gpt-realtime", + nil, + schemas.RTEventError, + rawResponse, + newRealtimeWireBifrostError(401, "server_error", "Virtual key is required."), + ) + if postErr != nil { + t.Fatalf("finalizeRealtimeTurnHooksWithError() post error = %v, want nil", postErr) + } + if capturedResp != nil { + t.Fatalf("captured response = %#v, want nil", capturedResp) + } + if capturedErr == nil { + t.Fatal("expected captured error") + } + if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if capturedErr.ExtraFields.Provider != schemas.OpenAI { + t.Fatalf("provider = %q, want %q", capturedErr.ExtraFields.Provider, schemas.OpenAI) + } + if capturedErr.ExtraFields.OriginalModelRequested != "gpt-realtime" { + t.Fatalf("model requested = %q, want %q", capturedErr.ExtraFields.OriginalModelRequested, "gpt-realtime") + } + rawRequest, ok := capturedErr.ExtraFields.RawRequest.(string) + if !ok || rawRequest == "" { + t.Fatalf("raw request = %#v, want non-empty string", capturedErr.ExtraFields.RawRequest) + } + rawResp, ok := capturedErr.ExtraFields.RawResponse.(json.RawMessage) + if !ok || string(rawResp) != string(rawResponse) { + t.Fatalf("raw response = %#v, want %s", capturedErr.ExtraFields.RawResponse, string(rawResponse)) + } + if session.PeekRealtimeTurnHooks() != nil { + t.Fatal("expected active hooks to be cleared") + } + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("remaining turn inputs = %d, want 0", len(got)) + } + if got := session.ConsumeRealtimeOutputText(); got != "" { + t.Fatalf("remaining output text = %q, want empty", got) + } + if !cleanedUp { + t.Fatal("expected realtime hook cleanup to run") + } +} + +func TestNewBifrostErrorFromRealtimeErrorCarriesRealtimeMetadata(t *testing.T) { + t.Parallel() + + rawResponse := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request_error","message":"bad request","param":"session.type"}}`) + bifrostErr := newBifrostErrorFromRealtimeError( + schemas.OpenAI, + "gpt-realtime", + rawResponse, + &schemas.RealtimeError{ + Type: "invalid_request_error", + Code: "invalid_request_error", + Message: "bad request", + Param: "session.type", + }, + ) + if bifrostErr == nil { + t.Fatal("expected bifrost error") + } + if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 400 { + t.Fatalf("status code = %#v, want 400", bifrostErr.StatusCode) + } + if bifrostErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", bifrostErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Fatalf("provider = %q, want %q", bifrostErr.ExtraFields.Provider, schemas.OpenAI) + } + if bifrostErr.ExtraFields.OriginalModelRequested != "gpt-realtime" { + t.Fatalf("model requested = %q, want %q", bifrostErr.ExtraFields.OriginalModelRequested, "gpt-realtime") + } + rawResp, ok := bifrostErr.ExtraFields.RawResponse.(json.RawMessage) + if !ok || string(rawResp) != string(rawResponse) { + t.Fatalf("raw response = %#v, want %s", bifrostErr.ExtraFields.RawResponse, string(rawResponse)) + } + if bifrostErr.Error == nil || bifrostErr.Error.Param != "session.type" { + t.Fatalf("error param = %#v, want session.type", bifrostErr.Error) + } +} diff --git a/transports/bifrost-http/handlers/realtime_turn_pipeline.go b/transports/bifrost-http/handlers/realtime_turn_pipeline.go new file mode 100644 index 0000000000..91095e5843 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_turn_pipeline.go @@ -0,0 +1,798 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" +) + +func newRealtimeTurnContext( + baseCtx *schemas.BifrostContext, + requestID string, + sessionID string, + providerSessionID string, + source realtimeTurnSource, + eventType schemas.RealtimeEventType, + key *schemas.Key, +) *schemas.BifrostContext { + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + if baseCtx != nil { + // Realtime post-hook contexts must preserve plugin-private values written in + // pre-hooks (for example telemetry start timestamps), not just public keys. + for ctxKey, value := range baseCtx.GetUserValues() { + if value != nil { + ctx.SetValue(ctxKey, value) + } + } + } + + ctx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if requestID == "" { + requestID = uuid.NewString() + } + ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) + resolvedSessionID := strings.TrimSpace(providerSessionID) + if resolvedSessionID == "" { + resolvedSessionID = strings.TrimSpace(sessionID) + } + if baseCtx != nil { + if externalSessionID, ok := baseCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && strings.TrimSpace(externalSessionID) != "" { + resolvedSessionID = strings.TrimSpace(externalSessionID) + } + } + if resolvedSessionID != "" { + ctx.SetValue(schemas.BifrostContextKeyParentRequestID, resolvedSessionID) + } + if strings.TrimSpace(providerSessionID) != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeSessionID, providerSessionID) + ctx.SetValue(schemas.BifrostContextKeyRealtimeProviderSessionID, providerSessionID) + } + if source != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeSource, string(source)) + } + if eventType != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeEventType, string(eventType)) + } + if key != nil { + if strings.TrimSpace(key.ID) != "" { + ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID) + } + if strings.TrimSpace(key.Name) != "" { + ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name) + } + } + return ctx +} + +func applyRealtimeTurnContextValues(ctx *schemas.BifrostContext, values map[any]any) { + if ctx == nil || len(values) == 0 { + return + } + for ctxKey, value := range values { + switch ctxKey { + case schemas.BifrostContextKeyRequestID, + schemas.BifrostContextKeyParentRequestID, + schemas.BifrostContextKeyRealtimeSessionID, + schemas.BifrostContextKeyRealtimeProviderSessionID, + schemas.BifrostContextKeyRealtimeSource, + schemas.BifrostContextKeyRealtimeEventType, + schemas.BifrostContextKeyStreamStartTime, + schemas.BifrostContextKeyStreamEndIndicator: + continue + } + if value != nil { + ctx.SetValue(ctxKey, value) + } + } +} + +func setRealtimeTurnStreamContext(ctx *schemas.BifrostContext, startedAt time.Time, isFinal bool) { + if ctx == nil { + return + } + if startedAt.IsZero() { + startedAt = time.Now() + } + ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, startedAt) + if isFinal { + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + } +} + +func buildRealtimeTurnPreRequest(provider schemas.ModelProvider, model string, turnInputs []bfws.RealtimeTurnInput) *schemas.BifrostRequest { + input := make([]schemas.ResponsesMessage, 0, len(turnInputs)) + for _, turnInput := range turnInputs { + summary := strings.TrimSpace(turnInput.Summary) + if summary == "" { + continue + } + switch turnInput.Role { + case string(schemas.ChatMessageRoleTool): + itemType := schemas.ResponsesMessageTypeFunctionCallOutput + output := &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: schemas.Ptr(summary), + } + input = append(input, schemas.ResponsesMessage{ + Type: &itemType, + ResponsesToolMessage: &schemas.ResponsesToolMessage{Output: output}, + }) + case string(schemas.ChatMessageRoleUser): + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleUser + input = append(input, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(summary)}, + }) + } + } + + return &schemas.BifrostRequest{ + RequestType: schemas.RealtimeRequest, + ResponsesRequest: &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + Input: input, + }, + } +} + +func buildRealtimeTurnPostResponse( + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + rawRequest string, + rawResponse []byte, + contentOverride string, + latency int64, +) *schemas.BifrostResponse { + output := buildRealtimeTurnOutputMessages(rtProvider, rawResponse, contentOverride) + resp := &schemas.BifrostResponsesResponse{ + Object: "response", + Model: model, + Output: output, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider, + OriginalModelRequested: model, + Latency: latency, + }, + } + if usage := extractRealtimeTurnUsage(rtProvider, rawResponse); usage != nil { + resp.Usage = buildRealtimeResponsesUsage(usage) + } + if strings.TrimSpace(rawRequest) != "" { + resp.ExtraFields.RawRequest = rawRequest + } + if len(rawResponse) > 0 { + resp.ExtraFields.RawResponse = string(rawResponse) + } + + return &schemas.BifrostResponse{ResponsesResponse: resp} +} + +func buildRealtimeTurnOutputMessages(rtProvider schemas.RealtimeProvider, rawResponse []byte, contentOverride string) []schemas.ResponsesMessage { + outputs := make([]schemas.ResponsesMessage, 0) + if outputMessage := extractRealtimeTurnOutputMessage(rtProvider, rawResponse, contentOverride); outputMessage != nil { + outputs = append(outputs, buildRealtimeResponsesMessagesFromChat(outputMessage, contentOverride)...) + } + + if len(outputs) > 0 { + return outputs + } + + var parsed realtimeResponseDoneEnvelope + if len(rawResponse) > 0 && schemas.Unmarshal(rawResponse, &parsed) == nil { + for _, item := range parsed.Response.Output { + switch item.Type { + case "message": + content := strings.TrimSpace(contentOverride) + if content == "" { + content = extractRealtimeResponseDoneContentText(item.Content) + } + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + msg := schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + } + if strings.TrimSpace(item.ID) != "" { + msg.ID = schemas.Ptr(strings.TrimSpace(item.ID)) + } + if content != "" { + msg.Content = &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)} + } + outputs = append(outputs, msg) + case "function_call": + itemType := schemas.ResponsesMessageTypeFunctionCall + msg := schemas.ResponsesMessage{ + Type: &itemType, + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: schemas.Ptr(strings.TrimSpace(item.Name)), + Arguments: schemas.Ptr(item.Arguments), + }, + } + if strings.TrimSpace(item.ID) != "" { + msg.ID = schemas.Ptr(strings.TrimSpace(item.ID)) + } + if strings.TrimSpace(item.CallID) != "" { + msg.CallID = schemas.Ptr(strings.TrimSpace(item.CallID)) + } + outputs = append(outputs, msg) + } + } + } + + if len(outputs) == 0 && strings.TrimSpace(contentOverride) != "" { + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + outputs = append(outputs, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentOverride))}, + }) + } + + return outputs +} + +func buildRealtimeResponsesMessagesFromChat(message *schemas.ChatMessage, contentOverride string) []schemas.ResponsesMessage { + if message == nil { + return nil + } + + outputs := make([]schemas.ResponsesMessage, 0, 1) + content := strings.TrimSpace(contentOverride) + if content == "" && message.Content != nil && message.Content.ContentStr != nil { + content = strings.TrimSpace(*message.Content.ContentStr) + } + if content != "" { + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + outputs = append(outputs, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)}, + }) + } + + if message.ChatAssistantMessage == nil { + return outputs + } + + for _, toolCall := range message.ChatAssistantMessage.ToolCalls { + itemType := schemas.ResponsesMessageTypeFunctionCall + msg := schemas.ResponsesMessage{ + Type: &itemType, + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Arguments: schemas.Ptr(toolCall.Function.Arguments), + }, + } + if toolCall.Function.Name != nil { + msg.ResponsesToolMessage.Name = schemas.Ptr(strings.TrimSpace(*toolCall.Function.Name)) + } + if toolCall.ID != nil { + msg.CallID = schemas.Ptr(strings.TrimSpace(*toolCall.ID)) + msg.ID = schemas.Ptr(strings.TrimSpace(*toolCall.ID)) + } + outputs = append(outputs, msg) + } + + return outputs +} + +func extractRealtimeResponseDoneContentText(content []realtimeResponseDoneContent) string { + for _, block := range content { + switch { + case strings.TrimSpace(block.Text) != "": + return strings.TrimSpace(block.Text) + case strings.TrimSpace(block.Transcript) != "": + return strings.TrimSpace(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + return strings.TrimSpace(block.Refusal) + } + } + return "" +} + +func buildRealtimeResponsesUsage(usage *schemas.BifrostLLMUsage) *schemas.ResponsesResponseUsage { + if usage == nil { + return nil + } + result := &schemas.ResponsesResponseUsage{ + InputTokens: usage.PromptTokens, + OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if usage.PromptTokensDetails != nil { + result.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + TextTokens: usage.PromptTokensDetails.TextTokens, + AudioTokens: usage.PromptTokensDetails.AudioTokens, + ImageTokens: usage.PromptTokensDetails.ImageTokens, + CachedReadTokens: usage.PromptTokensDetails.CachedReadTokens, + CachedWriteTokens: usage.PromptTokensDetails.CachedWriteTokens, + } + } + if usage.CompletionTokensDetails != nil { + result.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ + TextTokens: usage.CompletionTokensDetails.TextTokens, + AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens, + AudioTokens: usage.CompletionTokensDetails.AudioTokens, + ImageTokens: usage.CompletionTokensDetails.ImageTokens, + ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, + RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens, + CitationTokens: usage.CompletionTokensDetails.CitationTokens, + NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries, + } + } + return result +} + +func newRealtimeTurnErrorEventPayload(bifrostErr *schemas.BifrostError) []byte { + if bifrostErr == nil { + return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`) + } + + errorType, errorCode, errorMessage, errorParam := mapRealtimeWireErrorFields(bifrostErr) + payload := schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventError, + Error: &schemas.RealtimeError{ + Type: errorType, + Code: errorCode, + Message: errorMessage, + Param: errorParam, + }, + } + if data, err := schemas.Marshal(payload); err == nil { + return data + } + return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`) +} + +// isBudgetOrBillingError returns true if the lowercased value indicates a budget or billing exhaustion error. +// Quota/rate-limit patterns (quota_exceeded, quota exceeded, etc.) are already covered by bifrost.IsRateLimitErrorMessage. +func isBudgetOrBillingError(lower string) bool { + return strings.Contains(lower, "budget_exceeded") || + strings.Contains(lower, "budget exceeded") || + strings.Contains(lower, "insufficient_quota") || + strings.Contains(lower, "hard limit reached") || + strings.Contains(lower, "billing hard limit") +} + +func mapRealtimeWireErrorFields(bifrostErr *schemas.BifrostError) (string, string, string, string) { + errorType := "server_error" + errorCode := "server_error" + errorMessage := "internal server error" + errorParam := "" + + if bifrostErr == nil { + return errorType, errorCode, errorMessage, errorParam + } + + var values []string + if bifrostErr.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Type)) + } + if bifrostErr.Error != nil { + if bifrostErr.Error.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Type)) + } + if bifrostErr.Error.Code != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Code)) + } + if strings.TrimSpace(bifrostErr.Error.Message) != "" { + errorMessage = strings.TrimSpace(bifrostErr.Error.Message) + values = append(values, errorMessage) + } + if bifrostErr.Error.Param != nil { + errorParam = strings.TrimSpace(fmt.Sprint(bifrostErr.Error.Param)) + } + } + + for _, value := range values { + lower := strings.ToLower(value) + switch { + case lower == "": + continue + case strings.Contains(lower, "invalid_request_error"): + return "invalid_request_error", "invalid_request_error", errorMessage, errorParam + case isBudgetOrBillingError(lower): + return "insufficient_quota", "insufficient_quota", errorMessage, errorParam + case bifrost.IsRateLimitErrorMessage(lower): + return "rate_limit_exceeded", "rate_limit_exceeded", errorMessage, errorParam + } + } + + return errorType, errorCode, errorMessage, errorParam +} + +func shouldGracefullyDisconnectRealtime(bifrostErr *schemas.BifrostError) bool { + if bifrostErr == nil { + return false + } + + var values []string + if bifrostErr.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Type)) + } + if bifrostErr.Error != nil { + if bifrostErr.Error.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Type)) + } + if bifrostErr.Error.Code != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Code)) + } + values = append(values, strings.TrimSpace(bifrostErr.Error.Message)) + } + + for _, value := range values { + lower := strings.ToLower(value) + if lower == "" { + continue + } + if isBudgetOrBillingError(lower) || bifrost.IsRateLimitErrorMessage(lower) { + return true + } + } + + return false +} + +func startRealtimeTurnHooks( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + startEventType schemas.RealtimeEventType, +) *schemas.BifrostError { + if client == nil || session == nil { + return &schemas.BifrostError{ + Type: schemas.Ptr("server_error"), + StatusCode: schemas.Ptr(500), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("server_error"), + Message: "realtime turn pipeline is unavailable", + }, + } + } + if !session.TryBeginRealtimeTurnHooks() { + return &schemas.BifrostError{ + Type: schemas.Ptr("invalid_request_error"), + StatusCode: schemas.Ptr(400), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("invalid_request_error"), + Message: "Conversation already has an active response in progress.", + }, + } + } + committed := false + defer func() { + if !committed { + session.AbortRealtimeTurnHooks() + } + }() + + startedAt := time.Now() + turnCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, startEventType, key) + setRealtimeTurnStreamContext(turnCtx, startedAt, false) + req := buildRealtimeTurnPreRequest(provider, model, session.PeekRealtimeTurnInputs()) + hooks, bifrostErr := client.RunRealtimeTurnPreHooks(turnCtx, req) + if bifrostErr != nil { + // RunRealtimeTurnPreHooks already executed post-hooks and flushed the trace + // for this turn-start failure. Clear buffered turn state so transport-close + // fallback finalization does not emit the same error a second time. + session.ConsumeRealtimeTurnInputs() + session.ConsumeRealtimeOutputText() + return bifrostErr + } + + requestID, _ := turnCtx.Value(schemas.BifrostContextKeyRequestID).(string) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + PostHookRunner: hooks.PostHookRunner, + Cleanup: hooks.Cleanup, + RequestID: requestID, + StartedAt: startedAt, + PreHookValues: turnCtx.GetUserValues(), + }) + committed = true + return nil +} + +func finalizeRealtimeTurnHooks( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + rawResponse []byte, + contentOverride string, +) *schemas.BifrostError { + if client == nil || session == nil { + return nil + } + + turnInputs := session.ConsumeRealtimeTurnInputs() + rawRequest := combineRealtimeInputRaw(turnInputs) + + if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil { + defer func() { + if activeHooks.Cleanup != nil { + activeHooks.Cleanup() + } + }() + postResponse := buildRealtimeTurnPostResponse( + rtProvider, + provider, + model, + rawRequest, + rawResponse, + contentOverride, + time.Since(activeHooks.StartedAt).Milliseconds(), + ) + postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key) + applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues) + setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true) + _, bifrostErr := activeHooks.PostHookRunner(postCtx, postResponse, nil) + completeRealtimeTurnTrace(postCtx) + return bifrostErr + } + + startedAt := time.Now() + preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key) + setRealtimeTurnStreamContext(preCtx, startedAt, false) + preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs) + hooks, bifrostErr := client.RunRealtimeTurnPreHooks(preCtx, preReq) + if bifrostErr != nil { + return bifrostErr + } + if hooks.Cleanup != nil { + defer hooks.Cleanup() + } + + requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string) + postResponse := buildRealtimeTurnPostResponse( + rtProvider, + provider, + model, + rawRequest, + rawResponse, + contentOverride, + time.Since(startedAt).Milliseconds(), + ) + postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key) + applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues()) + setRealtimeTurnStreamContext(postCtx, startedAt, true) + _, bifrostErr = hooks.PostHookRunner(postCtx, postResponse, nil) + completeRealtimeTurnTrace(postCtx) + return bifrostErr +} + +func finalizeRealtimeTurnHooksWithError( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + eventType schemas.RealtimeEventType, + rawResponse []byte, + bifrostErr *schemas.BifrostError, +) *schemas.BifrostError { + if session == nil || bifrostErr == nil { + return nil + } + + turnInputs := session.ConsumeRealtimeTurnInputs() + rawRequest := combineRealtimeInputRaw(turnInputs) + session.ConsumeRealtimeOutputText() + + if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil { + defer func() { + if activeHooks.Cleanup != nil { + activeHooks.Cleanup() + } + }() + postErr := buildRealtimeTurnPostError( + provider, + model, + rawRequest, + rawResponse, + bifrostErr, + ) + postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key) + applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues) + setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true) + _, hookErr := activeHooks.PostHookRunner(postCtx, nil, postErr) + completeRealtimeTurnTrace(postCtx) + return hookErr + } + + if len(turnInputs) == 0 { + return nil + } + + if client == nil { + return nil + } + + startedAt := time.Now() + preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key) + setRealtimeTurnStreamContext(preCtx, startedAt, false) + preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs) + hooks, hookPreErr := client.RunRealtimeTurnPreHooks(preCtx, preReq) + if hookPreErr != nil { + return hookPreErr + } + if hooks.Cleanup != nil { + defer hooks.Cleanup() + } + + requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string) + postErr := buildRealtimeTurnPostError( + provider, + model, + rawRequest, + rawResponse, + bifrostErr, + ) + postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key) + applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues()) + setRealtimeTurnStreamContext(postCtx, startedAt, true) + _, hookErr := hooks.PostHookRunner(postCtx, nil, postErr) + completeRealtimeTurnTrace(postCtx) + return hookErr +} + +func buildRealtimeTurnPostError( + provider schemas.ModelProvider, + model string, + rawRequest string, + rawResponse []byte, + bifrostErr *schemas.BifrostError, +) *schemas.BifrostError { + if bifrostErr == nil { + return nil + } + + copied := *bifrostErr + copied.ExtraFields = bifrostErr.ExtraFields + if bifrostErr.Error != nil { + errorCopy := *bifrostErr.Error + copied.Error = &errorCopy + } + copied.ExtraFields.RequestType = schemas.RealtimeRequest + if copied.ExtraFields.Provider == "" { + copied.ExtraFields.Provider = provider + } + if strings.TrimSpace(copied.ExtraFields.OriginalModelRequested) == "" { + copied.ExtraFields.OriginalModelRequested = model + } + if strings.TrimSpace(rawRequest) != "" && copied.ExtraFields.RawRequest == nil { + copied.ExtraFields.RawRequest = rawRequest + } + if len(rawResponse) > 0 && copied.ExtraFields.RawResponse == nil { + copied.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...)) + } + return &copied +} + +func newBifrostErrorFromRealtimeError( + provider schemas.ModelProvider, + model string, + rawResponse []byte, + realtimeErr *schemas.RealtimeError, +) *schemas.BifrostError { + if realtimeErr == nil { + return nil + } + + statusCode := 500 + values := []string{ + strings.TrimSpace(realtimeErr.Type), + strings.TrimSpace(realtimeErr.Code), + strings.TrimSpace(realtimeErr.Message), + } + for _, value := range values { + lower := strings.ToLower(value) + switch { + case lower == "": + continue + case strings.Contains(lower, "invalid_request_error"): + statusCode = 400 + case isBudgetOrBillingError(lower), bifrost.IsRateLimitErrorMessage(lower): + statusCode = 429 + } + } + + errType := strings.TrimSpace(realtimeErr.Type) + if errType == "" { + errType = "server_error" + } + errCode := strings.TrimSpace(realtimeErr.Code) + if errCode == "" { + errCode = errType + } + message := strings.TrimSpace(realtimeErr.Message) + if message == "" { + message = "realtime turn failed" + } + + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(statusCode), + Type: schemas.Ptr(errType), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errType), + Code: schemas.Ptr(errCode), + Message: message, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + OriginalModelRequested: model, + RequestType: schemas.RealtimeRequest, + }, + } + if strings.TrimSpace(realtimeErr.Param) != "" { + bifrostErr.Error.Param = realtimeErr.Param + } + if len(rawResponse) > 0 { + bifrostErr.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...)) + } + return bifrostErr +} + +func completeRealtimeTurnTrace(ctx *schemas.BifrostContext) { + if ctx == nil { + return + } + traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if strings.TrimSpace(traceID) == "" { + return + } + tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) + if tracer == nil { + return + } + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) +} + +func finalizeRealtimeTurnHooksOnTransportError( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + status int, + code string, + message string, +) *schemas.BifrostError { + return finalizeRealtimeTurnHooksWithError( + client, + baseCtx, + session, + provider, + model, + key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(status, code, message), + ) +} diff --git a/transports/bifrost-http/handlers/utils.go b/transports/bifrost-http/handlers/utils.go index 554cf3aad3..bcc35d62fa 100644 --- a/transports/bifrost-http/handlers/utils.go +++ b/transports/bifrost-http/handlers/utils.go @@ -20,6 +20,13 @@ type pluginDisabledKey struct{} // PluginDisabledKey is the context key used to indicate a plugin is being disabled. var PluginDisabledKey pluginDisabledKey +// badRequestError wraps a client input validation error so that outer handlers +// can distinguish it from internal server errors and return HTTP 400. +type badRequestError struct{ err error } + +func (e *badRequestError) Error() string { return e.err.Error() } +func (e *badRequestError) Unwrap() error { return e.err } + // SendJSON sends a JSON response with 200 OK status func SendJSON(ctx *fasthttp.RequestCtx, data interface{}) { ctx.SetContentType("application/json") @@ -115,7 +122,7 @@ func IsOriginAllowed(origin string, allowedOrigins []string) bool { return true } - if allowedOrigin == "*" { + if allowedOrigin == "*" { return true } diff --git a/transports/bifrost-http/handlers/webrtc_realtime.go b/transports/bifrost-http/handlers/webrtc_realtime.go new file mode 100644 index 0000000000..644dbc593f --- /dev/null +++ b/transports/bifrost-http/handlers/webrtc_realtime.go @@ -0,0 +1,1215 @@ +package handlers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "github.com/valyala/fasthttp" +) + +const ( + webrtcRealtimeHandshakeTimeout = 10 * time.Second + webrtcRealtimeICEGatherTimeout = 3 * time.Second + webrtcRealtimeMaxPendingMessages = 1000 +) + +var defaultAudioCodec = webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", +} + +var realtimeSDPMaxMessageSizePattern = regexp.MustCompile(`(?m)^a=max-message-size:(\d+)\s*$`) + +type WebRTCRealtimeHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + mu sync.Mutex + relays map[string]*webrtcRealtimeRelay + legacyRoutes map[string]schemas.ModelProvider // path β†’ default provider (legacy raw-SDP routes) +} + +func NewWebRTCRealtimeHandler(client *bifrost.Bifrost, config *lib.Config) *WebRTCRealtimeHandler { + return &WebRTCRealtimeHandler{ + client: client, + config: config, + handlerStore: config, + relays: make(map[string]*webrtcRealtimeRelay), + legacyRoutes: make(map[string]schemas.ModelProvider), + } +} + +func (h *WebRTCRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleRequest, middlewares...) + + // Base bifrost route β€” GA /calls format (multipart sdp + session) + r.POST("/v1/realtime/calls", handler) + + // OpenAI integration routes β€” /calls variants (GA format) + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("/openai") { + r.POST(path, handler) + } + + // OpenAI integration routes β€” legacy variants (raw SDP, beta format) + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + h.legacyRoutes[path] = schemas.OpenAI + r.POST(path, handler) + } +} + +func (h *WebRTCRealtimeHandler) Close() { + if h == nil { + return + } + + h.mu.Lock() + relays := make([]*webrtcRealtimeRelay, 0, len(h.relays)) + for _, relay := range h.relays { + relays = append(relays, relay) + } + h.mu.Unlock() + + for _, relay := range relays { + relay.closeWithShutdownSignal() + } +} + +func (h *WebRTCRealtimeHandler) handleRequest(ctx *fasthttp.RequestCtx) { + if defaultProvider, isLegacy := h.legacyRoutes[string(ctx.Path())]; isLegacy { + h.handleLegacyRequest(ctx, defaultProvider) + } else { + h.handleCallsRequest(ctx) + } +} + +// handleCallsRequest handles the GA /realtime/calls format. +// Multipart bodies strictly require both "sdp" and "session" form fields β€” +// the model is read from session.model, not from a ?model= query param. +// Raw SDP bodies (application/sdp) fall back to ?model= for the legacy +// raw-SDP path only; the multipart contract has no ?model= fallback. +func (h *WebRTCRealtimeHandler) handleCallsRequest(ctx *fasthttp.RequestCtx) { + sdpOffer, providerKey, model, normalizedSession, bifrostErr := parseCallsWebRTCRequest(ctx) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + rtProvider, bifrostErr := h.resolveWebRTCProvider(providerKey) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + exchangeSDP := func(rCtx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError) { + return rtProvider.ExchangeRealtimeWebRTCSDP(rCtx, key, model, upstreamOffer, normalizedSession) + } + + h.runWebRTCRelay(ctx, rtProvider, providerKey, model, sdpOffer, exchangeSDP) +} + +func parseCallsWebRTCRequest(ctx *fasthttp.RequestCtx) (string, schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + contentType := strings.ToLower(string(ctx.Request.Header.ContentType())) + path := string(ctx.Path()) + if strings.HasPrefix(contentType, "multipart/form-data") { + form, err := ctx.MultipartForm() + if err != nil { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "failed to parse multipart form", err) + } + + sdpOffer := firstMultipartValue(form.Value, "sdp") + if strings.TrimSpace(sdpOffer) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "sdp form field is required", nil) + } + + sessionField := firstMultipartValue(form.Value, "session") + if strings.TrimSpace(sessionField) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session form field is required", nil) + } + providerKey, model, normalizedSession, bifrostErr := resolveRealtimeSDPTarget(path, []byte(sessionField)) + if bifrostErr != nil { + return "", "", "", nil, bifrostErr + } + return sdpOffer, providerKey, model, normalizedSession, nil + } + + sdpOffer := string(ctx.Request.Body()) + if strings.TrimSpace(sdpOffer) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "SDP is required", nil) + } + + rawModel := strings.TrimSpace(string(ctx.QueryArgs().Peek("model"))) + if rawModel == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model query param is required", nil) + } + + providerKey, model := schemas.ParseModelString(rawModel, realtimeDefaultProviderForPath(path)) + if providerKey == "" || strings.TrimSpace(model) == "" { + if realtimeDefaultProviderForPath(path) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model must use provider/model on /v1 realtime routes", nil) + } + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid model: "+rawModel, nil) + } + + return sdpOffer, providerKey, model, nil, nil +} + +// handleLegacyRequest handles the beta /realtime endpoint. +// Accepts both multipart (sdp + session) and raw SDP (application/sdp) from clients. +func (h *WebRTCRealtimeHandler) handleLegacyRequest(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) { + sdpOffer, rawModel, sessionJSON, bifrostErr := parseLegacyWebRTCRequest(ctx, defaultProvider) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + providerKey, model := schemas.ParseModelString(rawModel, defaultProvider) + if providerKey == "" || model == "" { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid model: "+rawModel, nil)) + return + } + + rtProvider, bifrostErr := h.resolveWebRTCProvider(providerKey) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + legacyProvider, ok := rtProvider.(schemas.RealtimeLegacyWebRTCProvider) + if !ok { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support legacy realtime WebRTC: "+string(providerKey), nil)) + return + } + + exchangeSDP := func(rCtx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError) { + return legacyProvider.ExchangeLegacyRealtimeWebRTCSDP(rCtx, key, upstreamOffer, sessionJSON, model) + } + + h.runWebRTCRelay(ctx, rtProvider, providerKey, model, sdpOffer, exchangeSDP) +} + +// parseLegacyWebRTCRequest extracts SDP, model, and optional session from a legacy request. +// Handles both multipart (sdp + session fields) and raw SDP (body + ?model= query param). +func parseLegacyWebRTCRequest(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) (sdpOffer, rawModel string, sessionJSON json.RawMessage, err *schemas.BifrostError) { + if strings.HasPrefix(strings.ToLower(string(ctx.Request.Header.ContentType())), "multipart/form-data") { + form, formErr := ctx.MultipartForm() + if formErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "failed to parse multipart form", formErr) + } + sdpOffer = firstMultipartValue(form.Value, "sdp") + if sessionField := firstMultipartValue(form.Value, "session"); sessionField != "" { + sessionJSON = json.RawMessage(sessionField) + if root, parseErr := schemas.ParseRealtimeClientSecretBody(sessionJSON); parseErr == nil { + if modelJSON, ok := root["model"]; ok { + var m string + if json.Unmarshal(modelJSON, &m) == nil { + rawModel = m + } + } + } + } + } else { + sdpOffer = string(ctx.Request.Body()) + } + + if strings.TrimSpace(sdpOffer) == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "SDP is required", nil) + } + + // Query param model takes precedence + if queryModel := strings.TrimSpace(string(ctx.QueryArgs().Peek("model"))); queryModel != "" { + rawModel = queryModel + } + if rawModel == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model is required (query param or session field)", nil) + } + + return sdpOffer, rawModel, sessionJSON, nil +} + +// runWebRTCRelay is the shared relay setup: creates bifrost context, selects key, establishes relay. +func (h *WebRTCRealtimeHandler) runWebRTCRelay( + ctx *fasthttp.RequestCtx, + rtProvider schemas.RealtimeProvider, + providerKey schemas.ModelProvider, + model string, + sdpOffer string, + exchangeSDP func(ctx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError), +) { + bifrostCtx, cancel := lib.ConvertToBifrostContext( + ctx, + h.handlerStore.ShouldAllowDirectKeys(), + h.config.GetHeaderMatcher(), + h.config.GetMCPHeaderCombinedAllowlist(), + ) + defer cancel() + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if strings.HasPrefix(string(ctx.Path()), "/openai") { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + + authKey, selectedKey, err := h.resolveRealtimeWebRTCKeys(ctx, bifrostCtx, providerKey, model) + if err != nil { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", err.Error(), nil)) + return + } + + // Resolve model alias so the provider receives the actual model identifier. + if selectedKey != nil { + model = selectedKey.Aliases.Resolve(model) + } else { + model = authKey.Aliases.Resolve(model) + } + + boundExchange := func(rCtx *schemas.BifrostContext, upstreamOffer string) (string, *schemas.BifrostError) { + return exchangeSDP(rCtx, authKey, upstreamOffer) + } + + relayCtx, relayCancel := newRealtimeRelayContext(bifrostCtx) + session := bfws.NewSession(nil) + browserAnswer, relayErr := h.establishRelay(relayCtx, relayCancel, session, rtProvider, providerKey, model, selectedKey, sdpOffer, boundExchange) + if relayErr != nil { + relayCancel() + SendBifrostError(ctx, relayErr) + return + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/sdp") + ctx.SetBodyString(browserAnswer) +} + +func (h *WebRTCRealtimeHandler) resolveRealtimeWebRTCKeys( + ctx *fasthttp.RequestCtx, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, +) (schemas.Key, *schemas.Key, error) { + inboundToken := extractRealtimeBearerToken(ctx) + mapping, mapped := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), inboundToken) + if mapped { + applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping) + } + if isRealtimeEphemeralToken(inboundToken) && !mapped { + bifrostCtx.ClearValue(schemas.BifrostContextKeyDirectKey) + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyID) + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyName) + bifrostCtx.ClearValue(schemas.BifrostContextKeySelectedKeyID) + bifrostCtx.ClearValue(schemas.BifrostContextKeySelectedKeyName) + authKey := schemas.Key{Value: *schemas.NewEnvVar(inboundToken)} + return authKey, nil, nil + } + + selectedKey, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + if err != nil && mapped && mapping.KeyID != "" { + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyID) + selectedKey, err = h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + } + if err != nil { + return schemas.Key{}, nil, err + } + + authKey := selectedKey + if mapped && inboundToken != "" { + authKey.Value = *schemas.NewEnvVar(inboundToken) + } + return authKey, &selectedKey, nil +} + +func lookupRealtimeEphemeralKeyMapping(kv schemas.KVStore, token string) (realtimeEphemeralKeyMapping, bool) { + if kv == nil || strings.TrimSpace(token) == "" { + return realtimeEphemeralKeyMapping{}, false + } + + raw, err := kv.Get(buildRealtimeEphemeralKeyMappingKey(token)) + if err != nil { + return realtimeEphemeralKeyMapping{}, false + } + + switch value := raw.(type) { + case string: + return parseRealtimeEphemeralKeyMappingValue([]byte(value)) + case []byte: + return parseRealtimeEphemeralKeyMappingValue(value) + default: + return realtimeEphemeralKeyMapping{}, false + } +} + +func parseRealtimeEphemeralKeyMappingValue(raw []byte) (realtimeEphemeralKeyMapping, bool) { + raw = []byte(strings.TrimSpace(string(raw))) + if len(raw) == 0 { + return realtimeEphemeralKeyMapping{}, false + } + + var mapping realtimeEphemeralKeyMapping + if err := json.Unmarshal(raw, &mapping); err == nil { + mapping.KeyID = strings.TrimSpace(mapping.KeyID) + mapping.VirtualKey = strings.TrimSpace(mapping.VirtualKey) + if mapping.KeyID != "" || mapping.VirtualKey != "" { + return mapping, true + } + } + + var keyID string + if err := json.Unmarshal(raw, &keyID); err == nil { + keyID = strings.TrimSpace(keyID) + if keyID != "" { + return realtimeEphemeralKeyMapping{KeyID: keyID}, true + } + } + + keyID = strings.TrimSpace(string(raw)) + if keyID == "" { + return realtimeEphemeralKeyMapping{}, false + } + return realtimeEphemeralKeyMapping{KeyID: keyID}, true +} + +func applyRealtimeEphemeralKeyMapping(bifrostCtx *schemas.BifrostContext, mapping realtimeEphemeralKeyMapping) { + if bifrostCtx == nil { + return + } + if mapping.VirtualKey != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, mapping.VirtualKey) + } + if mapping.KeyID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, mapping.KeyID) + } +} + +func extractRealtimeBearerToken(ctx *fasthttp.RequestCtx) string { + if ctx == nil { + return "" + } + return extractRealtimeBearerTokenFromHeader(string(ctx.Request.Header.Peek("Authorization"))) +} + +func extractRealtimeBearerTokenFromHeader(authHeader string) string { + authHeader = strings.TrimSpace(authHeader) + if len(authHeader) < len("Bearer ")+1 || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return "" + } + return strings.TrimSpace(authHeader[7:]) +} + +func isRealtimeEphemeralToken(token string) bool { + return strings.HasPrefix(strings.TrimSpace(token), "ek_") +} + +// resolveWebRTCProvider validates and returns a RealtimeProvider that supports WebRTC. +func (h *WebRTCRealtimeHandler) resolveWebRTCProvider(providerKey schemas.ModelProvider) (schemas.RealtimeProvider, *schemas.BifrostError) { + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider not found: "+string(providerKey), nil) + } + + rtProvider, ok := provider.(schemas.RealtimeProvider) + if !ok || !rtProvider.SupportsRealtimeAPI() { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support realtime: "+string(providerKey), nil) + } + + if !rtProvider.SupportsRealtimeWebRTC() { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support realtime WebRTC: "+string(providerKey), nil) + } + + return rtProvider, nil +} + +// establishRelay sets up the bidirectional WebRTC relay between the browser and the upstream provider. +// exchangeSDP is called with the upstream peer connection's SDP offer and must return the provider's +// SDP answer. This allows the handler to plug in different exchange strategies (GA calls vs legacy). +func (h *WebRTCRealtimeHandler) establishRelay( + relayCtx *schemas.BifrostContext, + relayCancel context.CancelFunc, + session *bfws.Session, + provider schemas.RealtimeProvider, + providerKey schemas.ModelProvider, + model string, + key *schemas.Key, + browserOffer string, + exchangeSDP func(ctx *schemas.BifrostContext, upstreamOffer string) (string, *schemas.BifrostError), +) (string, *schemas.BifrostError) { + downstreamPC, err := newRealtimePeerConnection() + if err != nil { + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser peer connection", err) + } + upstreamPC, err := newRealtimePeerConnection() + if err != nil { + _ = downstreamPC.Close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream peer connection", err) + } + + relay := &webrtcRealtimeRelay{ + client: h.client, + downstreamPC: downstreamPC, + upstreamPC: upstreamPC, + session: session, + bifrostCtx: relayCtx, + cancel: relayCancel, + provider: provider, + providerKey: providerKey, + model: model, + key: key, + } + relay.onClose = func() { + h.unregisterRelay(session.ID()) + } + relay.installCloseHandlers() + h.registerRelay(session.ID(), relay) + + // Downstream local audio track carries provider audio back to the browser. + providerToBrowserTrack, err := webrtc.NewTrackLocalStaticRTP(defaultAudioCodec, "audio", "bifrost-provider-audio") + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser audio track", err) + } + providerToBrowserSender, err := downstreamPC.AddTrack(providerToBrowserTrack) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to attach browser audio track", err) + } + relay.providerToBrowserTrack = providerToBrowserTrack + go relay.forwardRTCP(providerToBrowserSender, upstreamPC) + + // Upstream local audio track carries browser audio to the provider. + browserToProviderTrack, err := webrtc.NewTrackLocalStaticRTP(defaultAudioCodec, "audio", "bifrost-browser-audio") + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create provider audio track", err) + } + browserToProviderSender, err := upstreamPC.AddTrack(browserToProviderTrack) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to attach provider audio track", err) + } + relay.browserToProviderTrack = browserToProviderTrack + go relay.forwardRTCP(browserToProviderSender, downstreamPC) + + relay.installTrackForwarders() + if err := relay.installDataChannelRelay(); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream realtime data channel", err) + } + + if err := downstreamPC.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: browserOffer, + }); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid browser SDP offer", err) + } + + upstreamOffer, err := relay.createOffer(upstreamPC) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream SDP offer", err) + } + upstreamOffer = constrainRealtimeSDPMaxMessageSize(upstreamOffer, browserOffer) + + upstreamAnswer, exchangeErr := exchangeSDP(relayCtx, upstreamOffer) + if exchangeErr != nil { + relay.close() + return "", exchangeErr + } + + if err := upstreamPC.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: upstreamAnswer, + }); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadGateway, "upstream_connection_error", "invalid upstream SDP answer", err) + } + + waitCtx, waitCancel := context.WithTimeout(relayCtx, webrtcRealtimeHandshakeTimeout) + defer waitCancel() + + if err := relay.waitForUpstream(waitCtx); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadGateway, "upstream_connection_error", "upstream realtime WebRTC connection failed", err) + } + + browserAnswer, err := relay.createAnswer(downstreamPC) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser SDP answer", err) + } + + return browserAnswer, nil +} + +type webrtcRealtimeRelay struct { + client *bifrost.Bifrost + downstreamPC *webrtc.PeerConnection + upstreamPC *webrtc.PeerConnection + + downstreamChannel *webrtc.DataChannel + upstreamChannel *webrtc.DataChannel + + providerToBrowserTrack *webrtc.TrackLocalStaticRTP + browserToProviderTrack *webrtc.TrackLocalStaticRTP + + session *bfws.Session + bifrostCtx *schemas.BifrostContext + cancel context.CancelFunc + provider schemas.RealtimeProvider + providerKey schemas.ModelProvider + model string + key *schemas.Key + onClose func() + + closeOnce sync.Once + + channelMu sync.Mutex + pendingToUpstream []queuedDataChannelMessage + pendingToDownstream []queuedDataChannelMessage + upstreamConnectedOrError chan error +} + +type queuedDataChannelMessage struct { + payload []byte + isString bool +} + +func (r *webrtcRealtimeRelay) installCloseHandlers() { + r.upstreamConnectedOrError = make(chan error, 1) + + handleState := func(name string, pc *webrtc.PeerConnection) { + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected: + if name == "upstream" { + select { + case r.upstreamConnectedOrError <- nil: + default: + } + } + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + if name == "upstream" { + select { + case r.upstreamConnectedOrError <- fmt.Errorf("peer connection state %s", state.String()): + default: + } + } + r.close() + case webrtc.PeerConnectionStateDisconnected: + r.close() + } + }) + } + + handleState("downstream", r.downstreamPC) + handleState("upstream", r.upstreamPC) +} + +func (r *webrtcRealtimeRelay) installTrackForwarders() { + r.downstreamPC.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeAudio { + return + } + r.forwardRTPTrack(track, r.browserToProviderTrack) + }) + + r.upstreamPC.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeAudio { + return + } + r.forwardRTPTrack(track, r.providerToBrowserTrack) + }) +} + +func (r *webrtcRealtimeRelay) installDataChannelRelay() error { + label := strings.TrimSpace(r.provider.RealtimeWebRTCDataChannelLabel()) + if label == "" { + return nil + } + upstreamDC, err := r.upstreamPC.CreateDataChannel(label, nil) + if err != nil { + return err + } + r.bindUpstreamChannel(upstreamDC) + + r.downstreamPC.OnDataChannel(func(dc *webrtc.DataChannel) { + r.bindDownstreamChannel(dc) + }) + return nil +} + +func (r *webrtcRealtimeRelay) bindUpstreamChannel(dc *webrtc.DataChannel) { + r.channelMu.Lock() + r.upstreamChannel = dc + r.channelMu.Unlock() + + dc.OnOpen(func() { + r.flushPending() + }) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + r.handleUpstreamMessage(msg) + }) + dc.OnClose(func() { r.close() }) + dc.OnError(func(err error) { + logger.Warn("upstream realtime data channel error: %v", err) + r.close() + }) +} + +func (r *webrtcRealtimeRelay) bindDownstreamChannel(dc *webrtc.DataChannel) { + r.channelMu.Lock() + if r.downstreamChannel != nil { + r.channelMu.Unlock() + _ = dc.Close() + return + } + r.downstreamChannel = dc + r.channelMu.Unlock() + + dc.OnOpen(func() { + r.flushPending() + }) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + r.handleDownstreamMessage(msg) + }) + dc.OnClose(func() { r.close() }) + dc.OnError(func(err error) { + logger.Warn("browser realtime data channel error: %v", err) + r.close() + }) +} + +func (r *webrtcRealtimeRelay) createOffer(pc *webrtc.PeerConnection) (string, error) { + offer, err := pc.CreateOffer(nil) + if err != nil { + return "", err + } + gatherComplete := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(offer); err != nil { + return "", err + } + select { + case <-gatherComplete: + case <-time.After(webrtcRealtimeICEGatherTimeout): + } + if pc.LocalDescription() == nil { + return "", errors.New("local description not set") + } + return pc.LocalDescription().SDP, nil +} + +func (r *webrtcRealtimeRelay) createAnswer(pc *webrtc.PeerConnection) (string, error) { + answer, err := pc.CreateAnswer(nil) + if err != nil { + return "", err + } + gatherComplete := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(answer); err != nil { + return "", err + } + select { + case <-gatherComplete: + case <-time.After(webrtcRealtimeICEGatherTimeout): + } + if pc.LocalDescription() == nil { + return "", errors.New("local description not set") + } + return pc.LocalDescription().SDP, nil +} + +func (r *webrtcRealtimeRelay) waitForUpstream(ctx context.Context) error { + select { + case err := <-r.upstreamConnectedOrError: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func (r *webrtcRealtimeRelay) forwardRTPTrack(track *webrtc.TrackRemote, target *webrtc.TrackLocalStaticRTP) { + for { + packet, _, err := track.ReadRTP() + if err != nil { + return + } + if err := target.WriteRTP(packet); err != nil { + return + } + } +} + +func (r *webrtcRealtimeRelay) forwardRTCP(sender *webrtc.RTPSender, target *webrtc.PeerConnection) { + if sender == nil || target == nil { + return + } + buf := make([]byte, 1500) + for { + n, _, readErr := sender.Read(buf) + if readErr != nil { + return + } + pkts, parseErr := rtcp.Unmarshal(buf[:n]) + if parseErr != nil { + continue + } + if writeErr := target.WriteRTCP(pkts); writeErr != nil { + return + } + } +} + +func (r *webrtcRealtimeRelay) handleDownstreamMessage(msg webrtc.DataChannelMessage) { + event, err := schemas.ParseRealtimeEvent(msg.Data) + if err != nil { + logger.Warn("failed to parse browser realtime event: %v", err) + r.sendUpstream(msg.Data, msg.IsString) + return + } + toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event) + if toolSummary != "" { + r.session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(msg.Data)) + } + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if inputSummary != "" { + r.session.RecordRealtimeInput(inputItemID, inputSummary, string(msg.Data)) + } + startsTurn := r.provider.ShouldStartRealtimeTurn(event) + if startsTurn { + if r.session.PeekRealtimeTurnHooks() != nil { + r.sendDownstream(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress.")), true) + return + } + if bifrostErr := startRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, r.provider, r.providerKey, r.model, r.key, event.Type); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } + + providerEvent, err := r.provider.ToProviderRealtimeEvent(event) + if err != nil { + if startsTurn { + if finalizeErr := finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 400, + "invalid_request_error", + err.Error(), + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))) + return + } + logger.Warn("failed to translate browser realtime event: %v", err) + r.sendUpstream(msg.Data, msg.IsString) + return + } + r.sendUpstream(providerEvent, msg.IsString) +} + +func (r *webrtcRealtimeRelay) handleUpstreamMessage(msg webrtc.DataChannelMessage) { + event, err := r.provider.ToBifrostRealtimeEvent(msg.Data) + if err != nil { + if finalizeErr := finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 502, + "server_error", + "failed to translate upstream realtime event", + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + logger.Warn("failed to translate upstream realtime event: %v", err) + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))) + return + } + if event != nil { + if event.Session != nil && event.Session.ID != "" { + r.session.SetProviderSessionID(event.Session.ID) + } + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if inputSummary != "" { + r.session.RecordRealtimeInput(inputItemID, inputSummary, string(msg.Data)) + } + if event.Delta != nil && r.provider.ShouldAccumulateRealtimeOutput(event.Type) { + r.session.AppendRealtimeOutputText(event.Delta.Text) + r.session.AppendRealtimeOutputText(event.Delta.Transcript) + } + if r.provider.ShouldStartRealtimeTurn(event) && r.session.PeekRealtimeTurnHooks() == nil { + if bifrostErr := startRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, r.provider, r.providerKey, r.model, r.key, event.Type); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } + } + if event != nil { + if !r.provider.ShouldForwardRealtimeEvent(event) { + return + } + if event.Type == r.provider.RealtimeTurnFinalEvent() { + contentOverride := r.session.ConsumeRealtimeOutputText() + if bifrostErr := finalizeRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, r.provider, r.providerKey, r.model, r.key, msg.Data, contentOverride); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } else if event.Error != nil { + if finalizeErr := finalizeRealtimeTurnHooksWithError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + event.Type, + msg.Data, + newBifrostErrorFromRealtimeError(r.providerKey, r.model, msg.Data, event.Error), + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + } + msg.Data, err = r.provider.ToProviderRealtimeEvent(event) + if err != nil { + logger.Warn("failed to encode translated realtime event: %v", err) + // Lifecycle events (response.done / error) must reach the client so it + // can transition turn state β€” if encoding fails after the turn was + // finalized server-side, swallowing this would leave the client hung. + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload( + newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event: "+err.Error()), + )) + return + } + } + + r.sendDownstream(msg.Data, msg.IsString) +} + +func (r *webrtcRealtimeRelay) sendUpstream(payload []byte, isString bool) { + r.channelMu.Lock() + defer r.channelMu.Unlock() + if isDataChannelOpen(r.upstreamChannel) { + sendDataChannelMessage(r.upstreamChannel, payload, isString) + return + } + if len(r.pendingToUpstream) >= webrtcRealtimeMaxPendingMessages { + logger.Warn("upstream pending buffer exceeded %d messages, closing relay", webrtcRealtimeMaxPendingMessages) + go r.close() + return + } + r.pendingToUpstream = append(r.pendingToUpstream, queuedDataChannelMessage{payload: append([]byte(nil), payload...), isString: isString}) +} + +func (r *webrtcRealtimeRelay) sendDownstream(payload []byte, isString bool) { + r.channelMu.Lock() + defer r.channelMu.Unlock() + if isDataChannelOpen(r.downstreamChannel) { + sendDataChannelMessage(r.downstreamChannel, payload, isString) + return + } + if len(r.pendingToDownstream) >= webrtcRealtimeMaxPendingMessages { + logger.Warn("downstream pending buffer exceeded %d messages, closing relay", webrtcRealtimeMaxPendingMessages) + go r.close() + return + } + r.pendingToDownstream = append(r.pendingToDownstream, queuedDataChannelMessage{payload: append([]byte(nil), payload...), isString: isString}) +} + +func (r *webrtcRealtimeRelay) flushPending() { + r.channelMu.Lock() + defer r.channelMu.Unlock() + + if isDataChannelOpen(r.upstreamChannel) && len(r.pendingToUpstream) > 0 { + for _, msg := range r.pendingToUpstream { + sendDataChannelMessage(r.upstreamChannel, msg.payload, msg.isString) + } + r.pendingToUpstream = nil + } + if isDataChannelOpen(r.downstreamChannel) && len(r.pendingToDownstream) > 0 { + for _, msg := range r.pendingToDownstream { + sendDataChannelMessage(r.downstreamChannel, msg.payload, msg.isString) + } + r.pendingToDownstream = nil + } +} + +func (r *webrtcRealtimeRelay) close() { + r.closeOnce.Do(func() { + if r.session != nil { + _ = finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 502, + "connection_closed", + "realtime WebRTC session closed before turn completed", + ) + r.session.ClearRealtimeTurnHooks() + } + + if r.onClose != nil { + r.onClose() + } + if r.cancel != nil { + r.cancel() + } + + r.channelMu.Lock() + if r.downstreamChannel != nil { + _ = r.downstreamChannel.Close() + } + if r.upstreamChannel != nil { + _ = r.upstreamChannel.Close() + } + r.channelMu.Unlock() + + if r.downstreamPC != nil { + _ = r.downstreamPC.Close() + } + if r.upstreamPC != nil { + _ = r.upstreamPC.Close() + } + }) +} + +func (r *webrtcRealtimeRelay) closeWithShutdownSignal() { + r.close() +} + +func (r *webrtcRealtimeRelay) closeWithErrorEvent(payload []byte) { + r.channelMu.Lock() + dc := r.downstreamChannel + r.channelMu.Unlock() + + if isDataChannelOpen(dc) && len(payload) > 0 { + sendDataChannelMessage(dc, payload, true) + go func() { + time.Sleep(100 * time.Millisecond) + r.close() + }() + return + } + + r.close() +} + +func (h *WebRTCRealtimeHandler) registerRelay(sessionID string, relay *webrtcRealtimeRelay) { + if strings.TrimSpace(sessionID) == "" || relay == nil { + return + } + h.mu.Lock() + defer h.mu.Unlock() + h.relays[sessionID] = relay +} + +func (h *WebRTCRealtimeHandler) unregisterRelay(sessionID string) { + if strings.TrimSpace(sessionID) == "" { + return + } + h.mu.Lock() + defer h.mu.Unlock() + delete(h.relays, sessionID) +} + +func newRealtimeRelayContext(requestCtx *schemas.BifrostContext) (*schemas.BifrostContext, context.CancelFunc) { + relayCtx, cancel := schemas.NewBifrostContextWithCancel(context.Background()) + if requestCtx == nil { + return relayCtx, cancel + } + + for _, key := range []any{ + schemas.BifrostContextKeyRequestID, + schemas.BifrostContextKeyHTTPRequestType, + schemas.BifrostContextKeyIntegrationType, + schemas.BifrostContextKeyParentRequestID, + schemas.BifrostContextKeyVirtualKey, + schemas.BifrostContextKeyAPIKeyName, + schemas.BifrostContextKeyAPIKeyID, + schemas.BifrostContextKeyDirectKey, + schemas.BifrostContextKeyExtraHeaders, + schemas.BifrostContextKeyRequestHeaders, + schemas.BifrostContextKeyUserAgent, + schemas.BifrostContextKeyGovernanceVirtualKeyID, + schemas.BifrostContextKeyGovernanceVirtualKeyName, + schemas.BifrostContextKeyGovernanceRoutingRuleID, + schemas.BifrostContextKeyGovernanceRoutingRuleName, + schemas.BifrostContextKeyGovernanceCustomerID, + schemas.BifrostContextKeyGovernanceCustomerName, + schemas.BifrostContextKeyGovernanceTeamID, + schemas.BifrostContextKeyGovernanceTeamName, + schemas.BifrostContextKeyGovernanceUserID, + schemas.BifrostContextKeyGovernanceIncludeOnlyKeys, + schemas.BifrostContextKeyGovernancePluginName, + schemas.BifrostContextKeySelectedKeyID, + schemas.BifrostContextKeySelectedKeyName, + schemas.BifrostContextKeyIsEnterprise, + } { + if value := requestCtx.Value(key); value != nil { + relayCtx.SetValue(key, value) + } + } + + return relayCtx, cancel +} + +func newRealtimePeerConnection() (*webrtc.PeerConnection, error) { + return webrtc.NewPeerConnection(webrtc.Configuration{}) +} + +func isDataChannelOpen(dc *webrtc.DataChannel) bool { + return dc != nil && dc.ReadyState() == webrtc.DataChannelStateOpen +} + +func realtimeEventTypeFromPayload(payload []byte) string { + var envelope struct { + Type string `json:"type"` + } + if err := json.Unmarshal(payload, &envelope); err != nil { + return "" + } + return strings.TrimSpace(envelope.Type) +} + +func parseRealtimeSDPMaxMessageSize(sdp string) (int64, bool) { + matches := realtimeSDPMaxMessageSizePattern.FindStringSubmatch(sdp) + if len(matches) < 2 { + return 0, false + } + size, err := strconv.ParseInt(matches[1], 10, 64) + if err != nil || size <= 0 { + return 0, false + } + return size, true +} + +func setRealtimeSDPMaxMessageSize(sdp string, maxMessageSize int64) string { + line := "a=max-message-size:" + strconv.FormatInt(maxMessageSize, 10) + if realtimeSDPMaxMessageSizePattern.MatchString(sdp) { + return realtimeSDPMaxMessageSizePattern.ReplaceAllString(sdp, line) + } + if strings.Contains(sdp, "\r\nm=application ") { + return strings.Replace(sdp, "\r\nm=application ", "\r\n"+line+"\r\nm=application ", 1) + } + if strings.Contains(sdp, "\nm=application ") { + return strings.Replace(sdp, "\nm=application ", "\n"+line+"\nm=application ", 1) + } + return sdp +} + +func constrainRealtimeSDPMaxMessageSize(upstreamOffer string, browserOffer string) string { + browserMax, ok := parseRealtimeSDPMaxMessageSize(browserOffer) + if !ok { + return upstreamOffer + } + + upstreamMax, ok := parseRealtimeSDPMaxMessageSize(upstreamOffer) + if ok && upstreamMax <= browserMax { + return upstreamOffer + } + + return setRealtimeSDPMaxMessageSize(upstreamOffer, browserMax) +} + +func sendDataChannelMessage(dc *webrtc.DataChannel, payload []byte, isString bool) { + if dc == nil { + return + } + var err error + if isString { + err = dc.SendText(string(payload)) + } else { + err = dc.Send(payload) + } + if err != nil { + eventType := realtimeEventTypeFromPayload(payload) + if eventType != "" { + logger.Warn("failed to send realtime data channel message: type=%s size=%d bytes err=%v", eventType, len(payload), err) + return + } + logger.Warn("failed to send realtime data channel message: size=%d bytes err=%v", len(payload), err) + } +} + +func resolveRealtimeSDPTarget(path string, sessionJSON []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + root, err := schemas.ParseRealtimeClientSecretBody(sessionJSON) + if err != nil { + return "", "", nil, err + } + + modelJSON, ok := root["model"] + if !ok { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + + var rawModel string + if err := json.Unmarshal(modelJSON, &rawModel); err != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model must be a string", err) + } + + providerKey, model := schemas.ParseModelString(strings.TrimSpace(rawModel), realtimeDefaultProviderForPath(path)) + if providerKey == "" || strings.TrimSpace(model) == "" { + if realtimeDefaultProviderForPath(path) == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model must use provider/model on /v1 realtime routes", nil) + } + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + + normalizedModel, marshalErr := json.Marshal(model) + if marshalErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized session model", marshalErr) + } + root["model"] = normalizedModel + normalizedSession, marshalErr := json.Marshal(root) + if marshalErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized realtime session", marshalErr) + } + + return providerKey, strings.TrimSpace(model), normalizedSession, nil +} + +func firstMultipartValue(values map[string][]string, key string) string { + if len(values[key]) == 0 { + return "" + } + return values[key][0] +} + +func newRealtimeWebRTCError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + }, + } +} diff --git a/transports/bifrost-http/handlers/webrtc_realtime_test.go b/transports/bifrost-http/handlers/webrtc_realtime_test.go new file mode 100644 index 0000000000..a0c0d72c1a --- /dev/null +++ b/transports/bifrost-http/handlers/webrtc_realtime_test.go @@ -0,0 +1,346 @@ +package handlers + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + "github.com/valyala/fasthttp" +) + +type testHandlerStore struct { + kv *kvstore.Store +} + +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } +func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } +func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil } +func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { + return nil +} +func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil } +func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } +func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv } +func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil } + +func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) { + _, _, _, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`)) + if err == nil { + t.Fatal("expected provider/model validation error") + } + if err.Error == nil || err.Error.Message != "session.model must use provider/model on /v1 realtime routes" { + t.Fatalf("unexpected error: %#v", err) + } +} + +func TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel(t *testing.T) { + provider, model, normalized, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if provider != schemas.OpenAI { + t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected normalized model: %s", model) + } + + var root map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(normalized, &root); unmarshalErr != nil { + t.Fatalf("failed to unmarshal normalized session: %v", unmarshalErr) + } + var sessionModel string + if unmarshalErr := json.Unmarshal(root["model"], &sessionModel); unmarshalErr != nil { + t.Fatalf("failed to unmarshal model: %v", unmarshalErr) + } + if sessionModel != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected marshaled model: %s", sessionModel) + } +} + +func TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider(t *testing.T) { + provider, model, _, err := resolveRealtimeSDPTarget("/openai/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if provider != schemas.OpenAI { + t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected model: %s", model) + } +} + +func TestParseCallsWebRTCRequest_RawSDPKeepsGARoute(t *testing.T) { + var ctx fasthttp.RequestCtx + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.SetRequestURI("/openai/v1/realtime/calls?model=gpt-realtime") + ctx.Request.Header.SetContentType("application/sdp") + ctx.Request.SetBodyString("v=0\r\n") + + sdpOffer, provider, model, session, err := parseCallsWebRTCRequest(&ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sdpOffer != "v=0\r\n" { + t.Fatalf("unexpected sdp offer: %q", sdpOffer) + } + if provider != schemas.OpenAI { + t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-realtime" { + t.Fatalf("unexpected model: %s", model) + } + if session != nil { + t.Fatalf("expected nil session for raw SDP /calls request, got %s", string(session)) + } +} + +func TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation(t *testing.T) { + requestCtx, requestCancel := schemas.NewBifrostContextWithCancel(context.Background()) + requestCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + requestCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + requestCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, "vk_test") + + relayCtx, relayCancel := newRealtimeRelayContext(requestCtx) + defer relayCancel() + + requestCancel() + + select { + case <-requestCtx.Done(): + case <-time.After(time.Second): + t.Fatal("expected request context to be cancelled") + } + + select { + case <-relayCtx.Done(): + t.Fatal("relay context should outlive cancelled request context") + default: + } + + if got := relayCtx.Value(schemas.BifrostContextKeyHTTPRequestType); got != schemas.RealtimeRequest { + t.Fatalf("request type = %v, want %v", got, schemas.RealtimeRequest) + } + if got := relayCtx.Value(schemas.BifrostContextKeyIntegrationType); got != "openai" { + t.Fatalf("integration type = %v, want %q", got, "openai") + } + if got := relayCtx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID); got != "vk_test" { + t.Fatalf("virtual key id = %v, want %q", got, "vk_test") + } +} + +func TestParseRealtimeEventPreservesExtraParams(t *testing.T) { + event, err := schemas.ParseRealtimeEvent([]byte(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } +} + +func TestExtractRealtimeBearerToken(t *testing.T) { + var ctx fasthttp.RequestCtx + ctx.Request.Header.Set("Authorization", "Bearer ek_test_123") + + if got := extractRealtimeBearerToken(&ctx); got != "ek_test_123" { + t.Fatalf("extractRealtimeBearerToken() = %q, want %q", got, "ek_test_123") + } +} + +func TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + payload, err := json.Marshal(realtimeEphemeralKeyMapping{KeyID: "key_123", VirtualKey: "sk-bf-test"}) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_123"), payload, time.Minute); err != nil { + t.Fatalf("store.SetWithTTL() error = %v", err) + } + + mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_123") + if !ok { + t.Fatal("expected mapping to be consumed") + } + if mapping.KeyID != "key_123" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123") + } + if mapping.VirtualKey != "sk-bf-test" { + t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test") + } + + raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_123")) + if err != nil { + t.Fatalf("expected mapping to remain until TTL expiry: %v", err) + } + if raw == nil { + t.Fatal("expected mapping to remain in KV store") + } +} + +func TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_legacy"), "key_legacy", time.Minute); err != nil { + t.Fatalf("store.SetWithTTL() error = %v", err) + } + + mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_legacy") + if !ok { + t.Fatal("expected legacy mapping to be consumed") + } + if mapping.KeyID != "key_legacy" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_legacy") + } + if mapping.VirtualKey != "" { + t.Fatalf("mapping.VirtualKey = %q, want empty", mapping.VirtualKey) + } +} + +func TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks(t *testing.T) { + t.Parallel() + + session := bfws.NewSession(nil) + session.SetProviderSessionID("sess_provider_123") + session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`) + + var ( + capturedErr *schemas.BifrostError + cleanedUp bool + ) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + RequestID: "req_realtime_123", + StartedAt: time.Now().Add(-time.Second), + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + capturedErr = err + return result, nil + }, + Cleanup: func() { + cleanedUp = true + }, + }) + + relay := &webrtcRealtimeRelay{ + session: session, + providerKey: schemas.OpenAI, + model: "gpt-realtime", + } + + relay.close() + + if capturedErr == nil { + t.Fatal("expected active turn to be finalized with an error on close") + } + if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if capturedErr.Error == nil || capturedErr.Error.Message != "realtime WebRTC session closed before turn completed" { + t.Fatalf("error message = %#v, want realtime close message", capturedErr.Error) + } + if session.PeekRealtimeTurnHooks() != nil { + t.Fatal("expected active realtime turn hooks to be cleared") + } + if !cleanedUp { + t.Fatal("expected realtime hook cleanup to run") + } +} + +func TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + handler := &WebRTCRealtimeHandler{ + handlerStore: testHandlerStore{kv: store}, + } + + var ctx fasthttp.RequestCtx + ctx.Request.Header.Set("Authorization", "Bearer ek_test_unmapped") + + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ID: "header-provided"}) + bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, "selected") + bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, "selected-name") + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, "mapped-id") + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, "mapped-name") + + authKey, selectedKey, err := handler.resolveRealtimeWebRTCKeys(&ctx, bifrostCtx, schemas.OpenAI, "gpt-realtime") + if err != nil { + t.Fatalf("resolveRealtimeWebRTCKeys() error = %v", err) + } + if got := authKey.Value.GetValue(); got != "ek_test_unmapped" { + t.Fatalf("auth key value = %q, want %q", got, "ek_test_unmapped") + } + if selectedKey != nil { + t.Fatalf("selectedKey = %#v, want nil", selectedKey) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyDirectKey); got != nil { + t.Fatalf("direct key context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyID); got != nil { + t.Fatalf("selected key id context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyName); got != nil { + t.Fatalf("selected key name context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != nil { + t.Fatalf("api key id context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyName); got != nil { + t.Fatalf("api key name context = %#v, want nil", got) + } +} + +func TestApplyRealtimeEphemeralKeyMapping_RestoresVirtualKeyAndKeyID(t *testing.T) { + t.Parallel() + + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + applyRealtimeEphemeralKeyMapping(bifrostCtx, realtimeEphemeralKeyMapping{ + KeyID: "key_123", + VirtualKey: "sk-bf-test", + }) + + if got := bifrostCtx.Value(schemas.BifrostContextKeyVirtualKey); got != "sk-bf-test" { + t.Fatalf("virtual key context = %#v, want %q", got, "sk-bf-test") + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != "key_123" { + t.Fatalf("api key id context = %#v, want %q", got, "key_123") + } +} diff --git a/transports/bifrost-http/handlers/websocket.go b/transports/bifrost-http/handlers/websocket.go index 3f83cfdc64..93259291c8 100644 --- a/transports/bifrost-http/handlers/websocket.go +++ b/transports/bifrost-http/handlers/websocket.go @@ -180,26 +180,29 @@ func (h *WebSocketHandler) BroadcastLogUpdate(logEntry *logstore.Log) { operationType = "create" } - // Trim payload for table view: keep only the last input message and nil out - // large output fields that the table never renders. - if len(logEntry.InputHistoryParsed) > 1 { - logEntry.InputHistoryParsed = logEntry.InputHistoryParsed[len(logEntry.InputHistoryParsed)-1:] - } - if len(logEntry.ResponsesInputHistoryParsed) > 1 { - logEntry.ResponsesInputHistoryParsed = logEntry.ResponsesInputHistoryParsed[len(logEntry.ResponsesInputHistoryParsed)-1:] + // Trim payload for table view to keep websocket updates lightweight, but keep + // full realtime turns so the live table/detail sheet can still render the + // combined tool/user/assistant turn shape without waiting for a refresh. + if logEntry.Object != "realtime.turn" { + if len(logEntry.InputHistoryParsed) > 1 { + logEntry.InputHistoryParsed = logEntry.InputHistoryParsed[len(logEntry.InputHistoryParsed)-1:] + } + if len(logEntry.ResponsesInputHistoryParsed) > 1 { + logEntry.ResponsesInputHistoryParsed = logEntry.ResponsesInputHistoryParsed[len(logEntry.ResponsesInputHistoryParsed)-1:] + } + logEntry.OutputMessageParsed = nil + logEntry.ResponsesOutputParsed = nil + logEntry.EmbeddingOutputParsed = nil + logEntry.RerankOutputParsed = nil + logEntry.ParamsParsed = nil + logEntry.ToolsParsed = nil + logEntry.ToolCallsParsed = nil + logEntry.SpeechOutputParsed = nil + logEntry.TranscriptionOutputParsed = nil + logEntry.ImageGenerationOutputParsed = nil + logEntry.ListModelsOutputParsed = nil + logEntry.CacheDebugParsed = nil } - logEntry.OutputMessageParsed = nil - logEntry.ResponsesOutputParsed = nil - logEntry.EmbeddingOutputParsed = nil - logEntry.RerankOutputParsed = nil - logEntry.ParamsParsed = nil - logEntry.ToolsParsed = nil - logEntry.ToolCallsParsed = nil - logEntry.SpeechOutputParsed = nil - logEntry.TranscriptionOutputParsed = nil - logEntry.ImageGenerationOutputParsed = nil - logEntry.ListModelsOutputParsed = nil - logEntry.CacheDebugParsed = nil message := struct { Type string `json:"type"` diff --git a/transports/bifrost-http/handlers/wsrealtime.go b/transports/bifrost-http/handlers/wsrealtime.go new file mode 100644 index 0000000000..1d31c589e9 --- /dev/null +++ b/transports/bifrost-http/handlers/wsrealtime.go @@ -0,0 +1,666 @@ +package handlers + +import ( + "errors" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + ws "github.com/fasthttp/websocket" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + "github.com/valyala/fasthttp" +) + +const ( + realtimeWSPingInterval = 15 * time.Second + realtimeWSPongTimeout = 45 * time.Second + realtimeWSPingWriteTimeout = 10 * time.Second + realtimeWSWriteTimeout = 30 * time.Second +) + +// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API. +type WSRealtimeHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + pool *bfws.Pool + sessions *bfws.SessionManager +} + +// NewWSRealtimeHandler creates a new Realtime WebSocket handler. +func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler { + maxConns := config.WebSocketConfig.MaxConnections + + return &WSRealtimeHandler{ + client: client, + config: config, + handlerStore: config, + pool: pool, + sessions: bfws.NewSessionManager(maxConns), + } +} + +// RegisterRoutes registers the Realtime WebSocket endpoint at the base path and OpenAI integration paths. +func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...) + r.GET("/v1/realtime", handler) + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + r.GET(path, handler) + } +} + +func (h *WSRealtimeHandler) Close() { + if h == nil || h.sessions == nil { + return + } + h.sessions.CloseAll() +} + +func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + modelParam := string(ctx.QueryArgs().Peek("model")) + deploymentParam := string(ctx.QueryArgs().Peek("deployment")) + auth := captureAuthHeaders(ctx) + // OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.". + // Extract it into the auth headers so downstream processing recognizes it. + if auth.authorization == "" { + if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" { + auth.authorization = "Bearer " + token + } + } + + providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam) + if err != nil { + upgrader := h.websocketUpgrader("") + upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + }) + if upgradeErr != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr) + } + return + } + + provider := h.client.GetProviderByKey(providerKey) + rtProvider, ok := provider.(schemas.RealtimeProvider) + if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() { + upgrader := h.websocketUpgrader("") + upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey))) + }) + if upgradeErr != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr) + } + return + } + + upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol()) + err = upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + + session, sessionErr := h.sessions.Create(conn) + if sessionErr != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error())) + return + } + defer h.sessions.Remove(conn) + + h.runRealtimeSession(clientConn, session, auth, path, providerKey, model) + }) + if err != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, err) + } +} + +func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader { + upgrader := ws.FastHTTPUpgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { + origin := string(ctx.Request.Header.Peek("Origin")) + if origin == "" { + return true + } + return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins) + }, + } + if strings.TrimSpace(subprotocol) != "" { + upgrader.Subprotocols = []string{subprotocol} + } + return upgrader +} + +func (h *WSRealtimeHandler) runRealtimeSession( + clientConn *realtimeClientConn, + session *bfws.Session, + auth *authHeaders, + path string, + providerKey schemas.ModelProvider, + model string, +) { + clientConn.startHeartbeat() + defer clientConn.stopHeartbeat() + + bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth) + if bifrostCtx == nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context")) + return + } + defer cancel() + + // Resolve ephemeral key mapping to restore virtual key context. + token := extractRealtimeBearerTokenFromHeader(auth.authorization) + if isRealtimeEphemeralToken(token) { + mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token) + if ok { + applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping) + } + } + + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if strings.HasPrefix(path, "/openai") { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey))) + return + } + + rtProvider, ok := provider.(schemas.RealtimeProvider) + if !ok || !rtProvider.SupportsRealtimeAPI() { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey))) + return + } + + key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + return + } + + // Resolve model alias so the provider receives the actual model identifier. + model = key.Aliases.Resolve(model) + + wsURL := rtProvider.RealtimeWebSocketURL(key, model) + upstream, err := h.pool.Get(bfws.PoolKey{ + Provider: providerKey, + KeyID: key.ID, + Endpoint: wsURL, + }, rtProvider.RealtimeHeaders(key)) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error())) + return + } + defer h.pool.Discard(upstream) + + errCh := make(chan error, 2) + go func() { + errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key) + }() + go func() { + errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key) + }() + + firstErr := <-errCh + _ = upstream.Close() + _ = clientConn.Close() + secondErr := <-errCh + + if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil { + logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr) + } +} + +func (h *WSRealtimeHandler) relayClientToRealtimeProvider( + clientConn *realtimeClientConn, + session *bfws.Session, + upstream *bfws.UpstreamConn, + provider schemas.RealtimeProvider, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, + key schemas.Key, +) error { + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 499, + "client_closed_request", + "client realtime websocket disconnected before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + return err + } + if messageType != ws.TextMessage { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages")) + return nil + } + + event, err := schemas.ParseRealtimeEvent(message) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON")) + continue + } + // Extract pending tool/input summaries but defer recording until the event + // passes validation β€” rejected events must not pollute session state. + toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event) + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + + startsTurn := provider.ShouldStartRealtimeTurn(event) + if startsTurn { + if session.PeekRealtimeTurnHooks() != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress.")) + continue + } + if toolSummary != "" { + session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message)) + } + if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } + + providerEvent, err := provider.ToProviderRealtimeEvent(event) + if err != nil { + if startsTurn { + if finalizeErr := finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()), + ); finalizeErr != nil { + clientConn.writeRealtimeError(finalizeErr) + return nil + } + } + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + continue + } + + // Record tool output / input only after the event passed validation. + if !startsTurn { + if toolSummary != "" { + session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message)) + } + if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + } + + if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil { + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream")) + return err + } + } +} + +func (h *WSRealtimeHandler) relayRealtimeProviderToClient( + clientConn *realtimeClientConn, + session *bfws.Session, + upstream *bfws.UpstreamConn, + provider schemas.RealtimeProvider, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, + key schemas.Key, +) error { + for { + disconnectAfterWrite := false + messageType, message, err := upstream.ReadMessage() + if err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 502, + "upstream_connection_error", + "upstream realtime websocket closed before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted")) + return err + } + + if messageType == ws.TextMessage { + event, err := provider.ToBifrostRealtimeEvent(message) + if err != nil { + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + message, + newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event")) + return err + } + if event != nil { + if event.Session != nil && event.Session.ID != "" { + session.SetProviderSessionID(event.Session.ID) + } + if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) { + session.AppendRealtimeOutputText(event.Delta.Text) + session.AppendRealtimeOutputText(event.Delta.Transcript) + } + if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil { + if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } + } + if event != nil { + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if !provider.ShouldForwardRealtimeEvent(event) { + continue + } + if event.Type == provider.RealtimeTurnFinalEvent() { + contentOverride := session.ConsumeRealtimeOutputText() + if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } else if event.Error != nil { + turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error) + finalizeErr := finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + event.Type, + message, + turnErr, + ) + if finalizeErr != nil { + clientConn.writeRealtimeError(finalizeErr) + return nil + } + // Defer the disconnect so the normal translated-write path + // below still runs β€” otherwise terminal errors from translated + // providers would reach the client in provider-native format. + disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr) + } else if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + if len(event.RawData) == 0 { + message, err = provider.ToProviderRealtimeEvent(event) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event")) + return err + } + } + } + } + + if err := clientConn.WriteMessage(messageType, message); err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 499, + "client_closed_request", + "client realtime websocket disconnected before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + return err + } + if disconnectAfterWrite { + return nil + } + } +} + +func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) { + defaultProvider := realtimeDefaultProviderForPath(path) + + switch { + case strings.TrimSpace(modelParam) != "": + provider, model := schemas.ParseModelString(strings.TrimSpace(modelParam), defaultProvider) + if provider == "" || strings.TrimSpace(model) == "" { + return "", "", errRealtimeModelFormat + } + return provider, strings.TrimSpace(model), nil + case strings.TrimSpace(deploymentParam) != "": + provider, model := schemas.ParseModelString(strings.TrimSpace(deploymentParam), defaultProvider) + if provider == "" || strings.TrimSpace(model) == "" { + return "", "", errRealtimeDeploymentFormat + } + return provider, strings.TrimSpace(model), nil + default: + return "", "", errRealtimeModelRequired + } +} + +func realtimeDefaultProviderForPath(path string) schemas.ModelProvider { + if strings.HasPrefix(path, "/openai/") { + return schemas.OpenAI + } + return "" +} + +func isNormalWebSocketClosure(err error) bool { + return ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived) +} + +func isExpectedRealtimeRelayShutdown(err error) bool { + if err == nil { + return true + } + if isNormalWebSocketClosure(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return true + } + // Relay teardown closes the opposite socket after the first side exits, which can + // surface as a plain network-close read error instead of a websocket close frame. + return strings.Contains(err.Error(), "use of closed network connection") +} + +func selectRealtimeRelayError(errs ...error) error { + for _, err := range errs { + if err != nil && !isExpectedRealtimeRelayShutdown(err) { + return err + } + } + return nil +} + +var ( + errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket") + errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket") + errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket") +) + +type realtimeClientConn struct { + conn *ws.Conn + writeMu sync.Mutex + closeOnce sync.Once + done chan struct{} +} + +func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn { + return &realtimeClientConn{ + conn: conn, + done: make(chan struct{}), + } +} + +func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) { + messageType, p, err = c.conn.ReadMessage() + if err == nil { + c.refreshReadDeadline() + } + return messageType, p, err +} + +func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout)); err != nil { + return err + } + if err := c.conn.WriteMessage(messageType, data); err != nil { + return err + } + return c.conn.SetWriteDeadline(time.Time{}) +} + +func (c *realtimeClientConn) startHeartbeat() { + c.installPongHandler() + c.refreshReadDeadline() + + go func() { + ticker := time.NewTicker(realtimeWSPingInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := c.writePing(); err != nil { + _ = c.Close() + return + } + case <-c.done: + return + } + } + }() +} + +func (c *realtimeClientConn) stopHeartbeat() { + c.closeDone() +} + +func (c *realtimeClientConn) installPongHandler() { + c.conn.SetPongHandler(func(string) error { + return c.refreshReadDeadline() + }) +} + +func (c *realtimeClientConn) refreshReadDeadline() error { + return c.conn.SetReadDeadline(time.Now().Add(realtimeWSPongTimeout)) +} + +func (c *realtimeClientConn) writePing() error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout)); err != nil { + return err + } + if err := c.conn.WriteMessage(ws.PingMessage, nil); err != nil { + return err + } + return c.conn.SetWriteDeadline(time.Time{}) +} + +func (c *realtimeClientConn) closeDone() { + c.closeOnce.Do(func() { + close(c.done) + }) +} + +func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) { + payload := newRealtimeTurnErrorEventPayload(bifrostErr) + _ = c.WriteMessage(ws.TextMessage, payload) +} + +func (c *realtimeClientConn) Close() error { + c.closeDone() + return c.conn.Close() +} + +const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key." + +// extractRealtimeSubprotocolAPIKey extracts an API key from the Sec-WebSocket-Protocol +// header. The OpenAI SDK sends: "realtime, openai-insecure-api-key.". +func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string { + header := string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol")) + for _, proto := range strings.Split(header, ",") { + proto = strings.TrimSpace(proto) + if strings.HasPrefix(proto, realtimeSubprotocolAPIKeyPrefix) { + return strings.TrimPrefix(proto, realtimeSubprotocolAPIKeyPrefix) + } + } + return "" +} + +func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError { + errType := code + return &schemas.BifrostError{ + StatusCode: &status, + Type: &errType, + Error: &schemas.ErrorField{ + Type: &errType, + Code: &errType, + Message: message, + }, + } +} diff --git a/transports/bifrost-http/handlers/wsresponses.go b/transports/bifrost-http/handlers/wsresponses.go index b36ba31123..ca293a116e 100644 --- a/transports/bifrost-http/handlers/wsresponses.go +++ b/transports/bifrost-http/handlers/wsresponses.go @@ -58,6 +58,14 @@ func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bf } } +// Close gracefully shuts down all active WebSocket responses sessions. +func (h *WSResponsesHandler) Close() { + if h == nil || h.sessions == nil { + return + } + h.sessions.CloseAll() +} + // RegisterRoutes registers the WebSocket Responses endpoint at the base path // and all OpenAI integration paths. func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { @@ -98,6 +106,7 @@ type authHeaders struct { virtualKey string apiKey string googAPIKey string + baggage string extraHeaders map[string]string } @@ -108,6 +117,7 @@ func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")), apiKey: string(ctx.Request.Header.Peek("x-api-key")), googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")), + baggage: string(ctx.Request.Header.Peek("baggage")), extraHeaders: make(map[string]string), } @@ -192,7 +202,7 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a bifrostReq.Params.ExtraParams = extraParams } - bifrostCtx, cancel := h.createBifrostContext(auth) + bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth) if bifrostCtx == nil { writeWSError(session, 500, "server_error", "failed to create request context") return @@ -227,9 +237,10 @@ func (h *WSResponsesHandler) tryNativeWSUpstream( return false } - key, err := h.client.SelectKeyForProvider(ctx, req.Provider, req.Model) + key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model) if err != nil { - return false + writeWSError(session, 400, "invalid_request_error", err.Error()) + return true } wsURL := wsProvider.WebSocketResponsesURL(key) @@ -378,7 +389,7 @@ func parseUpstreamWSEvent(data []byte, provider schemas.ModelProvider, model str } streamResp.ExtraFields.RequestType = schemas.ResponsesStreamRequest streamResp.ExtraFields.Provider = provider - streamResp.ExtraFields.ModelRequested = model + streamResp.ExtraFields.OriginalModelRequested = model return &streamResp } @@ -495,10 +506,14 @@ func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketRespo }, nil } -// createBifrostContext builds a BifrostContext from the auth headers captured during upgrade. -func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) { +// createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade. +func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) { ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background()) + if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" { + ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID) + } + if auth.virtualKey != "" { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey) } @@ -508,12 +523,12 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if strings.HasPrefix(auth.authorization, "Bearer ") { token := strings.TrimPrefix(auth.authorization, "Bearer ") if strings.HasPrefix(token, "sk-bf-") { - ctx.SetValue(schemas.BifrostContextKeyVirtualKey, token) - } else if h.handlerStore.ShouldAllowDirectKeys() { + ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-")) + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(token), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) @@ -523,11 +538,11 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if auth.apiKey != "" { if strings.HasPrefix(auth.apiKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-")) - } else if h.handlerStore.ShouldAllowDirectKeys() { + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.apiKey), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) @@ -536,11 +551,11 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if auth.googAPIKey != "" { if strings.HasPrefix(auth.googAPIKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-")) - } else if h.handlerStore.ShouldAllowDirectKeys() { + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.googAPIKey), - Models: []string{}, + Models: schemas.WhiteList{"*"}, Weight: 1.0, } ctx.SetValue(schemas.BifrostContextKeyDirectKey, key) diff --git a/transports/bifrost-http/handlers/wsresponses_test.go b/transports/bifrost-http/handlers/wsresponses_test.go new file mode 100644 index 0000000000..aad3b15e9c --- /dev/null +++ b/transports/bifrost-http/handlers/wsresponses_test.go @@ -0,0 +1,68 @@ +package handlers + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +type testWSHandlerStore struct { + allowDirectKeys bool +} + +func (s testWSHandlerStore) ShouldAllowDirectKeys() bool { + return s.allowDirectKeys +} + +func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { + return nil +} + +func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider { + return nil +} + +func (s testWSHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { + return nil +} + +func (s testWSHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { + return nil +} + +func (s testWSHandlerStore) GetAsyncJobResultTTL() int { + return 0 +} + +func (s testWSHandlerStore) GetKVStore() *kvstore.Store { + return nil +} + +func (s testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return nil +} + +func TestCreateBifrostContextFromAuth_BaggageSessionIDSetsGrouping(t *testing.T) { + ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{ + baggage: "foo=bar, session-id=rt-ws-123, baz=qux", + }) + defer cancel() + + if got, _ := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-ws-123" { + t.Fatalf("parent request id = %q, want %q", got, "rt-ws-123") + } +} + +func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) { + ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{ + baggage: "session-id= ", + }) + defer cancel() + + if got := ctx.Value(schemas.BifrostContextKeyParentRequestID); got != nil { + t.Fatalf("parent request id should be unset, got %#v", got) + } +} diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index 25033b4928..7d3fdd9b22 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -43,7 +43,7 @@ func createAnthropicCompleteRouteConfig(pathPrefix string) RouteConfig { return nil, errors.New("invalid request type") }, TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { - if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -85,7 +85,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger return nil, errors.New("invalid request type") }, ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { - if isClaudeModel(resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment, string(resp.ExtraFields.Provider)) { + if isClaudeModel(resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed, string(resp.ExtraFields.Provider)) { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil } @@ -113,7 +113,7 @@ func createAnthropicMessagesRouteConfig(pathPrefix string, logger schemas.Logger }, StreamConfig: &StreamConfig{ ResponsesStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) { - if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.ModelRequested, resp.ExtraFields.ModelDeployment) { + if shouldUsePassthrough(ctx, resp.ExtraFields.Provider, resp.ExtraFields.OriginalModelRequested, resp.ExtraFields.ResolvedModelUsed) { if resp.ExtraFields.RawResponse != nil { raw, ok := resp.ExtraFields.RawResponse.(string) if !ok { @@ -396,15 +396,15 @@ func checkAnthropicPassthrough(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.Bif } // shouldUsePassthrough checks if the request should be sent to the passthrough endpoint. -func shouldUsePassthrough(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, deployment string) bool { - return anthropic.IsClaudeCodeRequest(ctx) && isClaudeModel(model, deployment, string(provider)) +func shouldUsePassthrough(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, alias string) bool { + return anthropic.IsClaudeCodeRequest(ctx) && isClaudeModel(model, alias, string(provider)) } -func isClaudeModel(model, deployment, provider string) bool { +func isClaudeModel(model, alias, provider string) bool { return (provider == string(schemas.Anthropic) || - (provider == "" && schemas.IsAnthropicModel(model))) || - (provider == string(schemas.Vertex) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) || - (provider == string(schemas.Azure) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(deployment))) + (provider == "" && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias)))) || + (provider == string(schemas.Vertex) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias))) || + (provider == string(schemas.Azure) && (schemas.IsAnthropicModel(model) || schemas.IsAnthropicModel(alias))) } // extractAnthropicListModelsParams extracts query parameters for list models request diff --git a/transports/bifrost-http/integrations/bedrock.go b/transports/bifrost-http/integrations/bedrock.go index 38ad98fbc6..b2dcb4174a 100644 --- a/transports/bifrost-http/integrations/bedrock.go +++ b/transports/bifrost-http/integrations/bedrock.go @@ -121,15 +121,21 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler Path: pathPrefix + "/model/{modelId}/invoke-with-response-stream", Method: "POST", GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { - return bedrock.DetectInvokeRequestType(ctx.Request.Body()) + modelID, _ := ctx.UserValue("modelId").(string) + return bedrock.DetectInvokeRequestType(ctx.Request.Body(), modelID) }, GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { if invokeReq, ok := req.(*bedrock.BedrockInvokeRequest); ok { + requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType) + switch requestType { + case schemas.EmbeddingRequest, schemas.ImageGenerationRequest, schemas.ImageEditRequest, schemas.ImageVariationRequest: + return nil, fmt.Errorf("request type %v is not supported on invoke-with-response-stream", requestType) + } invokeReq.Stream = true - if invokeReq.IsMessagesRequest() { + if requestType == schemas.ResponsesRequest { // Messages-based β†’ Responses path (streaming) converseReq := invokeReq.ToBedrockConverseRequest() responsesReq, err := converseReq.ToBifrostResponsesRequest(ctx) @@ -176,37 +182,72 @@ func createBedrockInvokeWithResponseStreamRouteConfig(pathPrefix string, handler // createBedrockInvokeRouteConfig creates a route configuration for the Bedrock Invoke API endpoint // Handles POST /bedrock/model/{modelId}/invoke // Uses BedrockInvokeRequest as a union type that supports all model families. -// Messages-based requests (Anthropic Messages, Nova, AI21) are routed through the Responses path, -// while prompt-based requests (Anthropic legacy, Mistral, Llama, Cohere) go through Text Completion. +// Request type is detected from the body + model ID and dispatched accordingly: +// - Embedding (Titan inputText, Cohere texts) +// - ImageGeneration (taskType=TEXT_IMAGE, Stability AI and other providers prompt-only) +// - ImageEdit (taskType=INPAINTING/OUTPAINTING/BACKGROUND_REMOVAL, Stability AI image+prompt) +// - ImageVariation (taskType=IMAGE_VARIATION) +// - ResponsesRequest (messages array β€” Anthropic Messages, Nova, AI21) +// - TextCompletionRequest (prompt β€” Anthropic legacy, Mistral, Llama, Cohere) func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerStore) RouteConfig { return RouteConfig{ Type: RouteConfigTypeBedrock, Path: pathPrefix + "/model/{modelId}/invoke", Method: "POST", GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType { - return bedrock.DetectInvokeRequestType(ctx.Request.Body()) + modelID, _ := ctx.UserValue("modelId").(string) + return bedrock.DetectInvokeRequestType(ctx.Request.Body(), modelID) }, GetRequestTypeInstance: func(ctx context.Context) interface{} { return &bedrock.BedrockInvokeRequest{} }, RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) { - if invokeReq, ok := req.(*bedrock.BedrockInvokeRequest); ok { - if invokeReq.IsMessagesRequest() { - // Messages-based (Anthropic Messages, Nova, AI21) β†’ Responses path - converseReq := invokeReq.ToBedrockConverseRequest() - responsesReq, err := converseReq.ToBifrostResponsesRequest(ctx) - if err != nil { - return nil, fmt.Errorf("failed to convert invoke messages request: %w", err) - } - return &schemas.BifrostRequest{ResponsesRequest: responsesReq}, nil + invokeReq, ok := req.(*bedrock.BedrockInvokeRequest) + if !ok { + return nil, errors.New("invalid request type") + } + + requestType, _ := ctx.Value(schemas.BifrostContextKeyHTTPRequestType).(schemas.RequestType) + switch requestType { + case schemas.EmbeddingRequest: + return &schemas.BifrostRequest{ + EmbeddingRequest: invokeReq.ToBifrostEmbeddingRequest(ctx), + }, nil + + case schemas.ImageGenerationRequest: + return &schemas.BifrostRequest{ + ImageGenerationRequest: invokeReq.ToBifrostImageGenerationRequest(ctx), + }, nil + + case schemas.ImageEditRequest: + editReq, err := invokeReq.ToBifrostImageEditRequest(ctx) + if err != nil { + return nil, fmt.Errorf("failed to convert invoke image edit request: %w", err) } - // Prompt-based (Anthropic legacy, Mistral, Llama, Cohere) β†’ Text Completion path - // Also handles Cohere Command R (message β†’ prompt conversion) + return &schemas.BifrostRequest{ImageEditRequest: editReq}, nil + + case schemas.ImageVariationRequest: + varReq, err := invokeReq.ToBifrostImageVariationRequest(ctx) + if err != nil { + return nil, fmt.Errorf("failed to convert invoke image variation request: %w", err) + } + return &schemas.BifrostRequest{ImageVariationRequest: varReq}, nil + + case schemas.ResponsesRequest: + // Messages-based (Anthropic Messages, Nova, AI21) -> Responses path + converseReq := invokeReq.ToBedrockConverseRequest() + responsesReq, err := converseReq.ToBifrostResponsesRequest(ctx) + if err != nil { + return nil, fmt.Errorf("failed to convert invoke messages request: %w", err) + } + return &schemas.BifrostRequest{ResponsesRequest: responsesReq}, nil + + default: + // TextCompletionRequest and any unrecognised type forwarded to text completion path return &schemas.BifrostRequest{ TextCompletionRequest: invokeReq.ToBifrostTextCompletionRequest(ctx), }, nil } - return nil, errors.New("invalid request type") }, TextResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTextCompletionResponse) (interface{}, error) { return bedrock.ToBedrockTextCompletionResponse(resp), nil @@ -214,6 +255,12 @@ func createBedrockInvokeRouteConfig(pathPrefix string, handlerStore lib.HandlerS ResponsesResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) { return bedrock.ToBedrockInvokeMessagesResponse(ctx, resp) }, + EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + return bedrock.ToBedrockEmbeddingInvokeResponse(resp) + }, + ImageGenerationResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) { + return bedrock.ToBedrockInvokeImagesResponse(ctx, resp) + }, ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} { return bedrock.ToBedrockError(err) }, diff --git a/transports/bifrost-http/integrations/bedrock_test.go b/transports/bifrost-http/integrations/bedrock_test.go index b2050471a1..16ab7ad4d5 100644 --- a/transports/bifrost-http/integrations/bedrock_test.go +++ b/transports/bifrost-http/integrations/bedrock_test.go @@ -16,9 +16,10 @@ import ( // mockHandlerStore implements lib.HandlerStore for testing type mockHandlerStore struct { - allowDirectKeys bool - headerMatcher *lib.HeaderMatcher - availableProviders []schemas.ModelProvider + allowDirectKeys bool + headerMatcher *lib.HeaderMatcher + availableProviders []schemas.ModelProvider + mcpHeaderCombinedAllowlist schemas.WhiteList } func (m *mockHandlerStore) ShouldAllowDirectKeys() bool { @@ -49,6 +50,10 @@ func (m *mockHandlerStore) GetKVStore() *kvstore.Store { return nil } +func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return m.mcpHeaderCombinedAllowlist +} + // Ensure mockHandlerStore implements lib.HandlerStore var _ lib.HandlerStore = (*mockHandlerStore)(nil) diff --git a/transports/bifrost-http/integrations/cursor.go b/transports/bifrost-http/integrations/cursor.go index a4ad12bc33..29513c1c05 100644 --- a/transports/bifrost-http/integrations/cursor.go +++ b/transports/bifrost-http/integrations/cursor.go @@ -104,10 +104,10 @@ func cursorChunkID(extras *schemas.BifrostResponseExtraFields) string { // cursorModel returns the best model name available from extra fields. func cursorModel(extras *schemas.BifrostResponseExtraFields) string { - if extras.ModelDeployment != "" { - return extras.ModelDeployment + if extras.ResolvedModelUsed != "" { + return extras.ResolvedModelUsed } - return extras.ModelRequested + return extras.OriginalModelRequested } // convertResponsesStreamToChatChunk maps a Responses API stream event to a diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 6bc6121f06..08959c8d24 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -279,16 +279,13 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ key := schemas.Key{ ID: uuid.New().String(), - Models: []string{}, + Models: schemas.WhiteList{"*"}, AzureKeyConfig: &schemas.AzureKeyConfig{}, } if deploymentEndpointStr != "" && deploymentIDStr != "" && azureKeyStr != "" { key.Value = *schemas.NewEnvVar(strings.TrimPrefix(azureKeyStr, "Bearer ")) key.AzureKeyConfig.Endpoint = *schemas.NewEnvVar(deploymentEndpointStr) - key.AzureKeyConfig.Deployments = map[string]string{ - deploymentIDStr: deploymentIDStr, - } } if apiVersionStr != "" { @@ -459,6 +456,9 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return resp, nil }, TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { + if schemas.IsPlainTextTranscriptionFormat(resp.ResponseFormat) { + return []byte(resp.Text), nil + } if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -693,7 +693,6 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return &schemas.BifrostRequest{ ResponsesRequest: openaiReq.ToBifrostResponsesRequest(ctx), }, nil - } return nil, errors.New("invalid request type") }, @@ -899,6 +898,9 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return nil, errors.New("invalid transcription request type") }, TranscriptionResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostTranscriptionResponse) (interface{}, error) { + if schemas.IsPlainTextTranscriptionFormat(resp.ResponseFormat) { + return []byte(resp.Text), nil + } if resp.ExtraFields.Provider == schemas.OpenAI { if resp.ExtraFields.RawResponse != nil { return resp.ExtraFields.RawResponse, nil @@ -2429,7 +2431,6 @@ func extractContainerListQueryParams(_ lib.HandlerStore) PreRequestCallback { // extractContainerIDFromPath extracts container_id from path parameters and provider from query params func extractContainerIDFromPath(_ lib.HandlerStore) PreRequestCallback { return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { - containerID := ctx.UserValue("container_id") if containerID == nil { return errors.New("container_id is required") @@ -2678,7 +2679,6 @@ func extractContainerFileCreateParams(_ lib.HandlerStore) PreRequestCallback { // extractContainerFileListQueryParams extracts query parameters for container file list requests func extractContainerFileListQueryParams(_ lib.HandlerStore) PreRequestCallback { return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { - containerID := ctx.UserValue("container_id") if containerID == nil { return errors.New("container_id is required") @@ -2725,7 +2725,6 @@ func extractContainerFileListQueryParams(_ lib.HandlerStore) PreRequestCallback // extractContainerAndFileIDFromPath extracts container_id and file_id from path parameters and provider from query params func extractContainerAndFileIDFromPath(handlerStore lib.HandlerStore) PreRequestCallback { return func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error { - containerID := ctx.UserValue("container_id") if containerID == nil { return errors.New("container_id is required") @@ -2803,6 +2802,35 @@ func OpenAIRealtimePaths(pathPrefix string) []string { return paths } +// OpenAIRealtimeWebRTCCallsPaths returns HTTP POST paths for the GA /realtime/calls +// WebRTC SDP exchange endpoint (multipart sdp + session format). +func OpenAIRealtimeWebRTCCallsPaths(pathPrefix string) []string { + basePaths := []string{ + "/v1/realtime/calls", + "/realtime/calls", + "/openai/realtime/calls", + } + paths := make([]string, 0, len(basePaths)) + for _, p := range basePaths { + paths = append(paths, pathPrefix+p) + } + return paths +} + +// OpenAIRealtimeClientSecretPaths returns HTTP POST paths for OpenAI-compatible +// realtime client secret creation aliases. +func OpenAIRealtimeClientSecretPaths(pathPrefix string) []string { + basePaths := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + } + paths := make([]string, 0, len(basePaths)) + for _, p := range basePaths { + paths = append(paths, pathPrefix+p) + } + return paths +} + // NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. func NewOpenAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *OpenAIRouter { routes := CreateOpenAIRouteConfigs("/openai", handlerStore) diff --git a/transports/bifrost-http/integrations/router.go b/transports/bifrost-http/integrations/router.go index 327781bb0b..ad915b0ebc 100644 --- a/transports/bifrost-http/integrations/router.go +++ b/transports/bifrost-http/integrations/router.go @@ -610,7 +610,7 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle var rawBody []byte // Execute the request through Bifrost - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) // Set integration type to context bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, string(config.Type)) @@ -1069,6 +1069,19 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf // Convert Bifrost response to integration-specific format and send response, err = config.TranscriptionResponseConverter(bifrostCtx, transcriptionResponse) providerResponseHeaders = transcriptionResponse.ExtraFields.ProviderResponseHeaders + + // If converter returns raw bytes, write directly with provider headers. + // Used for plain-text transcription formats (text, srt, vtt). + if err == nil { + if rawBytes, ok := response.([]byte); ok { + for key, value := range providerResponseHeaders { + ctx.Response.Header.Set(key, value) + } + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetBody(rawBytes) + return + } + } case bifrostReq.ImageGenerationRequest != nil: imageGenerationResponse, bifrostErr := g.client.ImageGenerationRequest(bifrostCtx, bifrostReq.ImageGenerationRequest) if bifrostErr != nil { @@ -1714,7 +1727,6 @@ func (g *GenericRouter) handleBatchRequest(ctx *fasthttp.RequestCtx, config Rout // handleFileRequest handles file API requests (upload, list, retrieve, delete, content) func (g *GenericRouter) handleFileRequest(ctx *fasthttp.RequestCtx, config RouteConfig, req interface{}, fileReq *FileRequest, bifrostCtx *schemas.BifrostContext) { - var response interface{} var err error @@ -2636,7 +2648,7 @@ func (g *GenericRouter) handlePassthrough(ctx *fasthttp.RequestCtx) { return true }) - bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher()) + bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, g.handlerStore.ShouldAllowDirectKeys(), g.handlerStore.GetHeaderMatcher(), g.handlerStore.GetMCPHeaderCombinedAllowlist()) if directKey := ctx.UserValue(string(schemas.BifrostContextKeyDirectKey)); directKey != nil { if key, ok := directKey.(schemas.Key); ok { bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index bd458ffc2b..6a721e5916 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -41,6 +41,7 @@ import ( "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "gorm.io/gorm" @@ -76,6 +77,8 @@ type HandlerStore interface { // GetKVStore returns the shared in-memory kvstore instance. // Returns nil if not initialized. GetKVStore() *kvstore.Store + // GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers + GetMCPHeaderCombinedAllowlist() schemas.WhiteList } // Retry backoff constants for validation @@ -101,6 +104,7 @@ func getWeight(w *float64) float64 { // IsBuiltinPlugin checks if a plugin is a built-in plugin func IsBuiltinPlugin(name string) bool { return name == telemetry.PluginName || + name == prompts.PluginName || name == logging.PluginName || name == governance.PluginName || name == litellmcompat.PluginName || @@ -172,56 +176,7 @@ func (cd *ConfigData) UnmarshalJSON(data []byte) error { if cd.Providers == nil { cd.Providers = make(map[string]configstore.ProviderConfig) } - // Extract provider configs from virtual keys. - // Keys can be either full definitions (with value) or references (name only). - // References are resolved by looking up the key by name from the providers section. - // NOTE: Only FULL key definitions (with Value) should be added to the provider. - // Reference lookups are for virtual key resolution only - they should NOT be added - // back to the provider since they already exist there. - if cd.Governance != nil && cd.Governance.VirtualKeys != nil { - for _, virtualKey := range cd.Governance.VirtualKeys { - if virtualKey.ProviderConfigs != nil { - for _, providerConfig := range virtualKey.ProviderConfigs { - // Only collect keys with Value (full definitions) to add to provider - var keysToAddToProvider []schemas.Key - for _, tableKey := range providerConfig.Keys { - if tableKey.Value.GetValue() != "" { - // Full key definition - add to provider - keysToAddToProvider = append(keysToAddToProvider, schemas.Key{ - ID: tableKey.KeyID, - Name: tableKey.Name, - Value: tableKey.Value, - Models: tableKey.Models, - BlacklistedModels: tableKey.BlacklistedModels, - Weight: getWeight(tableKey.Weight), - Enabled: tableKey.Enabled, - UseForBatchAPI: tableKey.UseForBatchAPI, - AzureKeyConfig: tableKey.AzureKeyConfig, - VertexKeyConfig: tableKey.VertexKeyConfig, - BedrockKeyConfig: tableKey.BedrockKeyConfig, - ReplicateKeyConfig: tableKey.ReplicateKeyConfig, - VLLMKeyConfig: tableKey.VLLMKeyConfig, - ConfigHash: tableKey.ConfigHash, - }) - } - // Reference lookups (no Value) are NOT added to provider - they already exist there - } - // Merge or create provider entry - only for full key definitions - if len(keysToAddToProvider) > 0 { - if existing, ok := cd.Providers[providerConfig.Provider]; ok { - existing.Keys = append(existing.Keys, keysToAddToProvider...) - cd.Providers[providerConfig.Provider] = existing - } else { - cd.Providers[providerConfig.Provider] = configstore.ProviderConfig{ - Keys: keysToAddToProvider, - } - } - } - } - } - } - } // Parse VectorStoreConfig using its internal unmarshaler if len(temp.VectorStoreConfig) > 0 { var vectorStoreConfig vectorstore.Config @@ -356,6 +311,7 @@ var DefaultClientConfig = configstore.ClientConfig{ MCPCodeModeBindingLevel: string(schemas.CodeModeBindingLevelServer), EnableLiteLLMFallbacks: false, HideDeletedVirtualKeysInFilters: false, + RoutingChainMaxDepth: governance.DefaultRoutingChainMaxDepth, } // LoadConfig loads initial configuration from a JSON config file into memory @@ -615,6 +571,9 @@ func applyClientConfigDefaults(cc *configstore.ClientConfig) { if cc.MCPAgentDepth == 0 { cc.MCPAgentDepth = DefaultClientConfig.MCPAgentDepth } + if cc.RoutingChainMaxDepth == 0 { + cc.RoutingChainMaxDepth = DefaultClientConfig.RoutingChainMaxDepth + } if cc.MCPToolExecutionTimeout == 0 { cc.MCPToolExecutionTimeout = DefaultClientConfig.MCPToolExecutionTimeout } @@ -760,6 +719,9 @@ func processProvider( if providerKeyInFile.ID == "" { providerCfgInFile.Keys[i].ID = uuid.NewString() } + if err := providerKeyInFile.Aliases.Validate(); err != nil { + return fmt.Errorf("invalid aliases for key %q in provider %s: %w", providerKeyInFile.Name, provider, err) + } } // Generate hash from config.json provider config fileProviderConfigHash, err := providerCfgInFile.GenerateConfigHash(string(provider)) @@ -843,7 +805,10 @@ func mergeProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []schema VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + Aliases: dbKey.Aliases, VLLMKeyConfig: dbKey.VLLMKeyConfig, + OllamaKeyConfig: dbKey.OllamaKeyConfig, + SGLKeyConfig: dbKey.SGLKeyConfig, Enabled: dbKey.Enabled, UseForBatchAPI: dbKey.UseForBatchAPI, }) @@ -921,7 +886,10 @@ func reconcileProviderKeys(provider schemas.ModelProvider, fileKeys, dbKeys []sc VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, ReplicateKeyConfig: dbKey.ReplicateKeyConfig, + Aliases: dbKey.Aliases, VLLMKeyConfig: dbKey.VLLMKeyConfig, + OllamaKeyConfig: dbKey.OllamaKeyConfig, + SGLKeyConfig: dbKey.SGLKeyConfig, Enabled: dbKey.Enabled, UseForBatchAPI: dbKey.UseForBatchAPI, }) @@ -1091,6 +1059,8 @@ func loadGovernanceConfig(ctx context.Context, config *Config, configData *Confi logger.Debug("no governance config found in store, processing from config file") config.GovernanceConfig = configData.Governance createGovernanceConfigInStore(ctx, config) + // Pricing overrides are loaded into ModelCatalog after initFrameworkConfig, + // once ModelCatalog is initialized. } else { logger.Debug("no governance config in store or config file") } @@ -1329,6 +1299,45 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf routingRulesToAdd = append(routingRulesToAdd, configData.Governance.RoutingRules[i]) } } + // Merge PricingOverrides by ID with hash comparison + pricingOverridesToAdd := make([]configstoreTables.TablePricingOverride, 0) + pricingOverridesToUpdate := make([]configstoreTables.TablePricingOverride, 0) + for i, newOverride := range configData.Governance.PricingOverrides { + if len(newOverride.RequestTypes) > 0 { + b, err := json.Marshal(newOverride.RequestTypes) + if err != nil { + logger.Warn("failed to serialize request_types for pricing override %s: %v", newOverride.ID, err) + continue + } + configData.Governance.PricingOverrides[i].RequestTypesJSON = string(b) + } else { + configData.Governance.PricingOverrides[i].RequestTypesJSON = "[]" + } + fileHash, err := configstore.GeneratePricingOverrideHash(configData.Governance.PricingOverrides[i]) + if err != nil { + logger.Warn("failed to generate pricing override hash for %s: %v", newOverride.ID, err) + continue + } + configData.Governance.PricingOverrides[i].ConfigHash = fileHash + + found := false + for j, existing := range governanceConfig.PricingOverrides { + if existing.ID == newOverride.ID { + found = true + if existing.ConfigHash != fileHash { + logger.Debug("config hash mismatch for pricing override %s, syncing from config file", newOverride.ID) + pricingOverridesToUpdate = append(pricingOverridesToUpdate, configData.Governance.PricingOverrides[i]) + governanceConfig.PricingOverrides[j] = configData.Governance.PricingOverrides[i] + } else { + logger.Debug("config hash matches for pricing override %s, keeping DB config", newOverride.ID) + } + break + } + } + if !found { + pricingOverridesToAdd = append(pricingOverridesToAdd, configData.Governance.PricingOverrides[i]) + } + } // Add merged items to config config.GovernanceConfig.Budgets = append(governanceConfig.Budgets, budgetsToAdd...) config.GovernanceConfig.RateLimits = append(governanceConfig.RateLimits, rateLimitsToAdd...) @@ -1336,13 +1345,15 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf config.GovernanceConfig.Teams = append(governanceConfig.Teams, teamsToAdd...) config.GovernanceConfig.VirtualKeys = append(governanceConfig.VirtualKeys, virtualKeysToAdd...) config.GovernanceConfig.RoutingRules = append(governanceConfig.RoutingRules, routingRulesToAdd...) + config.GovernanceConfig.PricingOverrides = append(governanceConfig.PricingOverrides, pricingOverridesToAdd...) // Update store with merged config items hasChanges := len(budgetsToAdd) > 0 || len(budgetsToUpdate) > 0 || len(rateLimitsToAdd) > 0 || len(rateLimitsToUpdate) > 0 || len(customersToAdd) > 0 || len(customersToUpdate) > 0 || len(teamsToAdd) > 0 || len(teamsToUpdate) > 0 || len(virtualKeysToAdd) > 0 || len(virtualKeysToUpdate) > 0 || - len(routingRulesToAdd) > 0 || len(routingRulesToUpdate) > 0 + len(routingRulesToAdd) > 0 || len(routingRulesToUpdate) > 0 || + len(pricingOverridesToAdd) > 0 || len(pricingOverridesToUpdate) > 0 if config.ConfigStore != nil && hasChanges { err := updateGovernanceConfigInStore(ctx, config, budgetsToAdd, budgetsToUpdate, @@ -1350,11 +1361,28 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf customersToAdd, customersToUpdate, teamsToAdd, teamsToUpdate, virtualKeysToAdd, virtualKeysToUpdate, - routingRulesToAdd, routingRulesToUpdate) + routingRulesToAdd, routingRulesToUpdate, + pricingOverridesToAdd, pricingOverridesToUpdate) if err != nil { logger.Fatal("failed to sync governance config: %v", err) } } + // Sync pricing overrides into the model catalog in one batch to avoid + // rebuilding the lookup map on every iteration. + if config.ModelCatalog != nil { + rows := make([]*configstoreTables.TablePricingOverride, 0, len(pricingOverridesToAdd)+len(pricingOverridesToUpdate)) + for i := range pricingOverridesToAdd { + rows = append(rows, &pricingOverridesToAdd[i]) + } + for i := range pricingOverridesToUpdate { + rows = append(rows, &pricingOverridesToUpdate[i]) + } + if len(rows) > 0 { + if err := config.ModelCatalog.UpsertPricingOverrides(rows...); err != nil { + logger.Error("failed to upsert pricing overrides into model catalog: %v", err) + } + } + } } // updateGovernanceConfigInStore updates governance config items in the store @@ -1373,6 +1401,8 @@ func updateGovernanceConfigInStore( virtualKeysToUpdate []configstoreTables.TableVirtualKey, routingRulesToAdd []configstoreTables.TableRoutingRule, routingRulesToUpdate []configstoreTables.TableRoutingRule, + pricingOverridesToAdd []configstoreTables.TablePricingOverride, + pricingOverridesToUpdate []configstoreTables.TablePricingOverride, ) error { logger.Debug("updating governance config in store with merged items") return config.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { @@ -1484,6 +1514,20 @@ func updateGovernanceConfigInStore( } } + // Create pricing overrides (new from config.json) + for _, override := range pricingOverridesToAdd { + if err := config.ConfigStore.CreatePricingOverride(ctx, &override, tx); err != nil { + return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err) + } + } + + // Update pricing overrides (config.json changed) + for _, override := range pricingOverridesToUpdate { + if err := config.ConfigStore.UpdatePricingOverride(ctx, &override, tx); err != nil { + return fmt.Errorf("failed to update pricing override %s: %w", override.ID, err) + } + } + return nil }) } @@ -1608,6 +1652,29 @@ func createGovernanceConfigInStore(ctx context.Context, config *Config) { virtualKey.MCPConfigs = mcpConfigs } + // Create pricing overrides after virtual keys so that scoped overrides referencing + // a virtual key ID are inserted after the VK row exists. + for i := range config.GovernanceConfig.PricingOverrides { + override := &config.GovernanceConfig.PricingOverrides[i] + if len(override.RequestTypes) > 0 { + b, err := json.Marshal(override.RequestTypes) + if err != nil { + return fmt.Errorf("failed to serialize request_types for pricing override %s: %w", override.ID, err) + } + override.RequestTypesJSON = string(b) + } else { + override.RequestTypesJSON = "[]" + } + overrideHash, err := configstore.GeneratePricingOverrideHash(*override) + if err != nil { + return fmt.Errorf("failed to generate pricing override hash for %s: %w", override.ID, err) + } + override.ConfigHash = overrideHash + if err := config.ConfigStore.CreatePricingOverride(ctx, override, tx); err != nil { + return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err) + } + } + return nil }); err != nil { logger.Warn("failed to update governance config: %v", err) @@ -1879,23 +1946,6 @@ func mergePlugins(ctx context.Context, config *Config, configData *ConfigData) { } } -// convertSchemasMCPClientConfigToTable converts schemas.MCPClientConfig to tables.TableMCPClient -func convertSchemasMCPClientConfigToTable(clientConfig *schemas.MCPClientConfig) *configstoreTables.TableMCPClient { - return &configstoreTables.TableMCPClient{ - ClientID: clientConfig.ID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: string(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, - Headers: clientConfig.Headers, - AuthType: string(clientConfig.AuthType), - OauthConfigID: clientConfig.OauthConfigID, - } -} - // buildMCPPricingDataFromStore builds MCP pricing data from the config store func buildMCPPricingDataFromStore(ctx context.Context, configStore configstore.ConfigStore) mcpcatalog.MCPPricingData { mcpPricingData := mcpcatalog.MCPPricingData{} @@ -2051,17 +2101,33 @@ func initFrameworkConfig(ctx context.Context, config *Config, configData *Config logger.Error("failed to initialize pricing manager: %v", err) } else { config.ModelCatalog = pricingManager - applyProviderPricingOverrides(config.ModelCatalog, config.Providers) } // Initialize MCP catalog - mcpCatalog, err := mcpcatalog.Init(ctx, &mcpcatalog.Config{ - PricingData: buildMCPPricingDataFromConfig(ctx, configData), - }, logger) + // Merge file-based pricing into mcpPricingConfig (DB data already loaded above). + // File config is used as fallback; DB values take precedence via the merge order. + if mcpPricingConfig.PricingData == nil { + mcpPricingConfig.PricingData = mcpcatalog.MCPPricingData{} + } + for k, v := range buildMCPPricingDataFromConfig(ctx, configData) { + if _, exists := mcpPricingConfig.PricingData[k]; !exists { + mcpPricingConfig.PricingData[k] = v + } + } + mcpCatalog, err := mcpcatalog.Init(ctx, mcpPricingConfig, logger) if err != nil { logger.Warn("failed to initialize MCP catalog: %v", err) } config.MCPCatalog = mcpCatalog + + // ModelCatalog is now initialized; replay pricing overrides for the no-store path. + // loadGovernanceConfig ran before ModelCatalog existed, so the in-memory + // load was skipped. Do it here now that ModelCatalog is available. + if config.ModelCatalog != nil && config.GovernanceConfig != nil && len(config.GovernanceConfig.PricingOverrides) > 0 { + if err := config.ModelCatalog.SetPricingOverrides(config.GovernanceConfig.PricingOverrides); err != nil { + logger.Warn("failed to set pricing overrides from config file: %v", err) + } + } } // initEncryption initializes encryption from config data or environment variables. @@ -2183,7 +2249,6 @@ func reconcileVirtualKeyAssociations( // Update existing provider config from file existing.Weight = newPC.Weight existing.AllowedModels = newPC.AllowedModels - existing.BudgetID = newPC.BudgetID existing.RateLimitID = newPC.RateLimitID existing.Keys = newPC.Keys if err := store.UpdateVirtualKeyProviderConfig(ctx, &existing, tx); err != nil { @@ -2324,6 +2389,114 @@ func (c *Config) SetHeaderMatcher(m *HeaderMatcher) { c.headerMatcher.Store(m) } +// GetMCPHeaderCombinedAllowlist returns the combined allowlist for MCP headers across all MCP clients. +// This method acquires a muMCP read lock and is safe for concurrent access from hot paths. +func (c *Config) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + c.muMCP.RLock() + defer c.muMCP.RUnlock() + + if c.MCPConfig == nil || len(c.MCPConfig.ClientConfigs) == 0 { + return schemas.WhiteList{} + } + + allowlist := schemas.WhiteList{} + for _, mcpClient := range c.MCPConfig.ClientConfigs { + if mcpClient == nil { + continue + } + if mcpClient.AllowedExtraHeaders.IsUnrestricted() { + return schemas.WhiteList{"*"} + } + allowlist = append(allowlist, mcpClient.AllowedExtraHeaders...) + } + return allowlist +} + +// GetAllowOnAllVirtualKeysClients returns a map of clientID -> clientName for all MCP clients +// that have AllowOnAllVirtualKeys enabled. The returned map is a copy, safe for concurrent use. +func (c *Config) GetAllowOnAllVirtualKeysClients() map[string]string { + c.muMCP.RLock() + defer c.muMCP.RUnlock() + + if c.MCPConfig == nil { + return nil + } + result := make(map[string]string) + for _, client := range c.MCPConfig.ClientConfigs { + if client != nil && client.AllowOnAllVirtualKeys { + result[client.ID] = client.Name + } + } + return result +} + +// GetPerUserOAuthMCPClients returns a map of clientID -> clientName for all MCP clients +// that have AuthType set to "per_user_oauth". The returned map is a copy, safe for concurrent use. +func (c *Config) GetPerUserOAuthMCPClients() map[string]string { + c.muMCP.RLock() + defer c.muMCP.RUnlock() + + if c.MCPConfig == nil { + return nil + } + result := make(map[string]string) + for _, client := range c.MCPConfig.ClientConfigs { + if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth { + result[client.ID] = client.Name + } + } + return result +} + +// GetPerUserOAuthMCPClientsForVirtualKey returns a map of clientID -> clientName for +// per_user_oauth MCP clients that the given VK is allowed to use. A client is included if: +// - AllowOnAllVirtualKeys is true, OR +// - The VK has an explicit entry in governance_virtual_key_mcp_configs for that client. +// +// If virtualKeyID is empty, all per-user OAuth clients are returned. If the config store +// is unavailable or the VK lookup fails, only clients with AllowOnAllVirtualKeys=true are returned. +func (c *Config) GetPerUserOAuthMCPClientsForVirtualKey(ctx context.Context, virtualKeyID string) map[string]string { + all := c.GetPerUserOAuthMCPClients() + if virtualKeyID == "" { + return all + } + + // Build set of per-user OAuth clients that allow all virtual keys. + c.muMCP.RLock() + allowAll := make(map[string]string) + if c.MCPConfig != nil { + for _, client := range c.MCPConfig.ClientConfigs { + if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth && client.AllowOnAllVirtualKeys { + allowAll[client.ID] = client.Name + } + } + } + c.muMCP.RUnlock() + + if c.ConfigStore == nil { + return allowAll + } + + // Get VK-specific MCP configs (with MCPClient preloaded so we have the string ClientID). + vkConfigs, err := c.ConfigStore.GetVirtualKeyMCPConfigs(ctx, virtualKeyID) + if err != nil { + // Fail closed: only return clients that are allowed on all virtual keys. + return allowAll + } + explicit := make(map[string]bool, len(vkConfigs)) + for _, cfg := range vkConfigs { + explicit[cfg.MCPClient.ClientID] = true + } + + result := make(map[string]string) + for clientID, clientName := range all { + if _, ok := allowAll[clientID]; ok || explicit[clientID] { + result[clientID] = clientName + } + } + return result +} + // GetPluginOrder returns the names of all base plugins in their sorted placement order. // This method is lock-free and safe for concurrent access from hot paths. // Do not modify the returned slice; it is a shared snapshot and must be treated read-only. @@ -2355,9 +2528,19 @@ type pluginChunkInterceptor struct { // Plugins are called in reverse order (same as PostHook) so modifications chain correctly. func (i *pluginChunkInterceptor) InterceptChunk(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, stream *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { for j := len(i.plugins) - 1; j >= 0; j-- { - modified, err := i.plugins[j].HTTPTransportStreamChunkHook(ctx, req, stream) + plugin := i.plugins[j] + pluginName := plugin.GetName() + var ( + modified *schemas.BifrostStreamChunk + err error + ) + func() { + pluginCtx := ctx.WithPluginScope(&pluginName) + defer pluginCtx.ReleasePluginScope() + modified, err = plugin.HTTPTransportStreamChunkHook(pluginCtx, req, stream) + }() if err != nil { - return modified, fmt.Errorf("failed to intercept chunk with plugin %s: %w", i.plugins[j].GetName(), err) + return modified, fmt.Errorf("failed to intercept chunk with plugin %s: %w", pluginName, err) } if modified == nil { return nil, nil // Plugin wants to skip this chunk @@ -2839,6 +3022,76 @@ func (c *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*con return config.Redacted(), nil } +// GetProviderKeysRaw retrieves the raw keys configured for a provider. +func (c *Config) GetProviderKeysRaw(provider schemas.ModelProvider) ([]schemas.Key, error) { + c.Mu.RLock() + defer c.Mu.RUnlock() + + config, exists := c.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + keys := append([]schemas.Key(nil), config.Keys...) + return keys, nil +} + +// GetProviderKeysRedacted retrieves redacted keys configured for a provider. +func (c *Config) GetProviderKeysRedacted(provider schemas.ModelProvider) ([]schemas.Key, error) { + c.Mu.RLock() + defer c.Mu.RUnlock() + + config, exists := c.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + return append([]schemas.Key(nil), config.Redacted().Keys...), nil +} + +// GetProviderKeyRaw retrieves a single raw key configured for a provider. +func (c *Config) GetProviderKeyRaw(provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { + c.Mu.RLock() + defer c.Mu.RUnlock() + + config, exists := c.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + index := slices.IndexFunc(config.Keys, func(key schemas.Key) bool { + return key.ID == keyID + }) + if index == -1 { + return nil, ErrNotFound + } + + key := config.Keys[index] + return &key, nil +} + +// GetProviderKeyRedacted retrieves a single redacted key configured for a provider. +func (c *Config) GetProviderKeyRedacted(provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { + c.Mu.RLock() + defer c.Mu.RUnlock() + + config, exists := c.Providers[provider] + if !exists { + return nil, ErrNotFound + } + + redacted := config.Redacted() + index := slices.IndexFunc(redacted.Keys, func(key schemas.Key) bool { + return key.ID == keyID + }) + if index == -1 { + return nil, ErrNotFound + } + + key := redacted.Keys[index] + return &key, nil +} + // GetAllProviders returns all configured provider names. func (c *Config) GetAllProviders() ([]schemas.ModelProvider, error) { c.Mu.RLock() @@ -2992,6 +3245,162 @@ func (c *Config) UpdateProviderConfig(ctx context.Context, provider schemas.Mode return nil } +// AddProviderKey adds a new key to an existing provider configuration. +func (c *Config) AddProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key) error { + c.Mu.Lock() + defer c.Mu.Unlock() + + existingConfig, exists := c.Providers[provider] + if !exists { + return ErrNotFound + } + + if key.ID == "" { + key.ID = uuid.NewString() + } + + updatedConfig := existingConfig + updatedConfig.Keys = append(append([]schemas.Key(nil), existingConfig.Keys...), key) + + skipDBUpdate := false + if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil { + if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok { + skipDBUpdate = skip + } + } + if c.ConfigStore != nil && !skipDBUpdate { + if err := c.ConfigStore.CreateProviderKey(ctx, provider, key); err != nil { + if errors.Is(err, configstore.ErrNotFound) { + return ErrNotFound + } + return fmt.Errorf("failed to create provider key in store: %w", err) + } + } + + c.Providers[provider] = updatedConfig + + c.Mu.Unlock() + clientErr := c.client.UpdateProvider(provider) + c.Mu.Lock() + + if clientErr != nil { + if reflect.DeepEqual(c.Providers[provider], updatedConfig) { + c.Providers[provider] = existingConfig + } + return fmt.Errorf("failed to update provider: %w", clientErr) + } + + logger.Info("Added key %s to provider: %s", key.ID, provider) + return nil +} + +// UpdateProviderKey updates a single key on an existing provider configuration. +func (c *Config) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key) error { + c.Mu.Lock() + defer c.Mu.Unlock() + + existingConfig, exists := c.Providers[provider] + if !exists { + return ErrNotFound + } + + index := slices.IndexFunc(existingConfig.Keys, func(existingKey schemas.Key) bool { + return existingKey.ID == keyID + }) + if index == -1 { + return ErrNotFound + } + + updatedConfig := existingConfig + updatedConfig.Keys = append([]schemas.Key(nil), existingConfig.Keys...) + key.ID = keyID + updatedConfig.Keys[index] = key + + skipDBUpdate := false + if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil { + if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok { + skipDBUpdate = skip + } + } + if c.ConfigStore != nil && !skipDBUpdate { + if err := c.ConfigStore.UpdateProviderKey(ctx, provider, keyID, key); err != nil { + if errors.Is(err, configstore.ErrNotFound) { + return ErrNotFound + } + return fmt.Errorf("failed to update provider key in store: %w", err) + } + } + + c.Providers[provider] = updatedConfig + + c.Mu.Unlock() + clientErr := c.client.UpdateProvider(provider) + c.Mu.Lock() + + if clientErr != nil { + if reflect.DeepEqual(c.Providers[provider], updatedConfig) { + c.Providers[provider] = existingConfig + } + return fmt.Errorf("failed to update provider: %w", clientErr) + } + + logger.Info("Updated key %s for provider: %s", keyID, provider) + return nil +} + +// RemoveProviderKey removes a single key from an existing provider configuration. +func (c *Config) RemoveProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) error { + c.Mu.Lock() + defer c.Mu.Unlock() + + existingConfig, exists := c.Providers[provider] + if !exists { + return ErrNotFound + } + + index := slices.IndexFunc(existingConfig.Keys, func(existingKey schemas.Key) bool { + return existingKey.ID == keyID + }) + if index == -1 { + return ErrNotFound + } + + updatedConfig := existingConfig + updatedConfig.Keys = append([]schemas.Key(nil), existingConfig.Keys[:index]...) + updatedConfig.Keys = append(updatedConfig.Keys, existingConfig.Keys[index+1:]...) + + skipDBUpdate := false + if ctx.Value(schemas.BifrostContextKeySkipDBUpdate) != nil { + if skip, ok := ctx.Value(schemas.BifrostContextKeySkipDBUpdate).(bool); ok { + skipDBUpdate = skip + } + } + if c.ConfigStore != nil && !skipDBUpdate { + if err := c.ConfigStore.DeleteProviderKey(ctx, provider, keyID); err != nil { + if errors.Is(err, configstore.ErrNotFound) { + return ErrNotFound + } + return fmt.Errorf("failed to delete provider key from store: %w", err) + } + } + + c.Providers[provider] = updatedConfig + + c.Mu.Unlock() + clientErr := c.client.UpdateProvider(provider) + c.Mu.Lock() + + if clientErr != nil { + if reflect.DeepEqual(c.Providers[provider], updatedConfig) { + c.Providers[provider] = existingConfig + } + return fmt.Errorf("failed to update provider: %w", clientErr) + } + + logger.Info("Removed key %s from provider: %s", keyID, provider) + return nil +} + // RemoveProvider removes a provider configuration from memory. func (c *Config) RemoveProvider(ctx context.Context, provider schemas.ModelProvider) error { c.Mu.Lock() @@ -3032,16 +3441,63 @@ func (c *Config) GetAllKeys() ([]configstoreTables.TableKey, error) { if blacklisted == nil { blacklisted = []string{} } - keys = append(keys, configstoreTables.TableKey{ + configStoreKey := configstoreTables.TableKey{ KeyID: key.ID, Name: key.Name, - Value: *schemas.NewEnvVar(""), + Value: *key.Value.Redacted(), Models: models, BlacklistedModels: blacklisted, Weight: bifrost.Ptr(key.Weight), Provider: string(providerKey), ConfigHash: key.ConfigHash, - }) + } + if key.AzureKeyConfig != nil { + cfg := *key.AzureKeyConfig // safe copy + cfg.Endpoint = *cfg.Endpoint.Redacted() + cfg.ClientID = cfg.ClientID.Redacted() + cfg.ClientSecret = cfg.ClientSecret.Redacted() + cfg.TenantID = cfg.TenantID.Redacted() + configStoreKey.AzureKeyConfig = &cfg + } + if key.BedrockKeyConfig != nil { + cfg := *key.BedrockKeyConfig // safe copy + cfg.ARN = key.BedrockKeyConfig.ARN.Redacted() + cfg.AccessKey = *cfg.AccessKey.Redacted() + cfg.ExternalID = cfg.ExternalID.Redacted() + cfg.Region = cfg.Region.Redacted() + cfg.RoleARN = cfg.RoleARN.Redacted() + cfg.RoleSessionName = cfg.RoleSessionName.Redacted() + cfg.SecretKey = *cfg.SecretKey.Redacted() + cfg.SessionToken = cfg.SessionToken.Redacted() + configStoreKey.BedrockKeyConfig = &cfg + } + if key.VertexKeyConfig != nil { + cfg := *key.VertexKeyConfig // safe copy + cfg.ProjectID = *cfg.ProjectID.Redacted() + cfg.ProjectNumber = *cfg.ProjectNumber.Redacted() + cfg.Region = *cfg.Region.Redacted() + cfg.AuthCredentials = *cfg.AuthCredentials.Redacted() + configStoreKey.VertexKeyConfig = &cfg + } + if key.ReplicateKeyConfig != nil { + configStoreKey.ReplicateKeyConfig = key.ReplicateKeyConfig + } + if key.VLLMKeyConfig != nil { + cfg := *key.VLLMKeyConfig // safe copy + cfg.URL = *cfg.URL.Redacted() + configStoreKey.VLLMKeyConfig = &cfg + } + if key.OllamaKeyConfig != nil { + cfg := *key.OllamaKeyConfig // safe copy + cfg.URL = *cfg.URL.Redacted() + configStoreKey.OllamaKeyConfig = &cfg + } + if key.SGLKeyConfig != nil { + cfg := *key.SGLKeyConfig // safe copy + cfg.URL = *cfg.URL.Redacted() + configStoreKey.SGLKeyConfig = &cfg + } + keys = append(keys, configStoreKey) } } @@ -3201,9 +3657,11 @@ func (c *Config) UpdateMCPClient(ctx context.Context, id string, updatedConfig * c.MCPConfig.ClientConfigs[configIndex].Headers = updatedConfig.Headers c.MCPConfig.ClientConfigs[configIndex].ToolsToExecute = updatedConfig.ToolsToExecute c.MCPConfig.ClientConfigs[configIndex].ToolsToAutoExecute = updatedConfig.ToolsToAutoExecute + c.MCPConfig.ClientConfigs[configIndex].AllowedExtraHeaders = updatedConfig.AllowedExtraHeaders c.MCPConfig.ClientConfigs[configIndex].ToolPricing = updatedConfig.ToolPricing c.MCPConfig.ClientConfigs[configIndex].IsPingAvailable = updatedConfig.IsPingAvailable c.MCPConfig.ClientConfigs[configIndex].ToolSyncInterval = updatedConfig.ToolSyncInterval + c.MCPConfig.ClientConfigs[configIndex].AllowOnAllVirtualKeys = updatedConfig.AllowOnAllVirtualKeys return nil } @@ -3302,7 +3760,7 @@ func (c *Config) autoDetectProviders(ctx context.Context) { ID: keyID, Name: fmt.Sprintf("%s_auto_detected", envVar), Value: *schemas.NewEnvVar(apiKey), - Models: []string{}, // Empty means all supported models + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, @@ -3599,14 +4057,3 @@ func DeepCopy[T any](in T) (T, error) { err = sonic.Unmarshal(b, &out) return out, err } - -func applyProviderPricingOverrides(catalog *modelcatalog.ModelCatalog, providers map[schemas.ModelProvider]configstore.ProviderConfig) { - if catalog == nil { - return - } - for provider, providerConfig := range providers { - if err := catalog.SetProviderPricingOverrides(provider, providerConfig.PricingOverrides); err != nil { - logger.Warn("failed to load pricing overrides for provider %s: %v", provider, err) - } - } -} diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index fffd287c6e..0723aeb2ff 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -240,8 +240,7 @@ End-to-end tests for virtual key provider configuration operations. | TestSQLite_VKProviderConfig_KeyReference | VK provider config key references work | | TestSQLite_VKProviderConfig_HashChangesOnKeyIDChange | Hash changes when key ID changes | | TestSQLite_VKProviderConfig_WeightAndAllowedModels | Weight and allowed models handled correctly | -| TestSQLite_VKProviderConfig_BudgetAndRateLimit | BudgetID/RateLimitID persisted correctly | -| TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit | VK hash includes provider config budget/rate limit | +| TestGenerateVirtualKeyHash_ProviderConfigRateLimit | VK hash includes provider config rate limit | =================================================================================== SQLITE INTEGRATION TESTS - VK MCP CONFIGS @@ -449,6 +448,68 @@ func (m *MockConfigStore) DeleteProvider(ctx context.Context, provider schemas.M return nil } +func (m *MockConfigStore) GetProviderKeys(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + config, ok := m.providers[provider] + if !ok { + return nil, configstore.ErrNotFound + } + return append([]schemas.Key(nil), config.Keys...), nil +} + +func (m *MockConfigStore) GetProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string) (*schemas.Key, error) { + config, ok := m.providers[provider] + if !ok { + return nil, configstore.ErrNotFound + } + for _, key := range config.Keys { + if key.ID == keyID { + keyCopy := key + return &keyCopy, nil + } + } + return nil, configstore.ErrNotFound +} + +func (m *MockConfigStore) CreateProviderKey(ctx context.Context, provider schemas.ModelProvider, key schemas.Key, tx ...*gorm.DB) error { + config, ok := m.providers[provider] + if !ok { + return configstore.ErrNotFound + } + config.Keys = append(config.Keys, key) + m.providers[provider] = config + return nil +} + +func (m *MockConfigStore) UpdateProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, key schemas.Key, tx ...*gorm.DB) error { + config, ok := m.providers[provider] + if !ok { + return configstore.ErrNotFound + } + for i := range config.Keys { + if config.Keys[i].ID == keyID { + config.Keys[i] = key + m.providers[provider] = config + return nil + } + } + return configstore.ErrNotFound +} + +func (m *MockConfigStore) DeleteProviderKey(ctx context.Context, provider schemas.ModelProvider, keyID string, tx ...*gorm.DB) error { + config, ok := m.providers[provider] + if !ok { + return configstore.ErrNotFound + } + for i := range config.Keys { + if config.Keys[i].ID == keyID { + config.Keys = append(config.Keys[:i], config.Keys[i+1:]...) + m.providers[provider] = config + return nil + } + } + return configstore.ErrNotFound +} + // MCP config func (m *MockConfigStore) GetMCPConfig(ctx context.Context) (*schemas.MCPConfig, error) { return m.mcpConfig, nil @@ -489,30 +550,32 @@ func (m *MockConfigStore) UpdateMCPClientConfig(ctx context.Context, id string, if m.mcpConfig.ClientConfigs[i].ID == id { // Found the entry, update it with the new config m.mcpConfig.ClientConfigs[i] = &schemas.MCPClientConfig{ - ID: clientConfig.ClientID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - Headers: clientConfig.Headers, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + ID: clientConfig.ClientID, + Name: clientConfig.Name, + IsCodeModeClient: clientConfig.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + Headers: clientConfig.Headers, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + AllowedExtraHeaders: clientConfig.AllowedExtraHeaders, } return nil } } // If not found, create a new entry (similar to CreateMCPClientConfig behavior) m.mcpConfig.ClientConfigs = append(m.mcpConfig.ClientConfigs, &schemas.MCPClientConfig{ - ID: clientConfig.ClientID, - Name: clientConfig.Name, - IsCodeModeClient: clientConfig.IsCodeModeClient, - ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), - ConnectionString: clientConfig.ConnectionString, - StdioConfig: clientConfig.StdioConfig, - Headers: clientConfig.Headers, - ToolsToExecute: clientConfig.ToolsToExecute, - ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + ID: clientConfig.ClientID, + Name: clientConfig.Name, + IsCodeModeClient: clientConfig.IsCodeModeClient, + ConnectionType: schemas.MCPConnectionType(clientConfig.ConnectionType), + ConnectionString: clientConfig.ConnectionString, + StdioConfig: clientConfig.StdioConfig, + Headers: clientConfig.Headers, + ToolsToExecute: clientConfig.ToolsToExecute, + ToolsToAutoExecute: clientConfig.ToolsToAutoExecute, + AllowedExtraHeaders: clientConfig.AllowedExtraHeaders, }) return nil @@ -522,6 +585,10 @@ func (m *MockConfigStore) GetMCPClientsPaginated(ctx context.Context, params con return nil, 0, nil } +func (m *MockConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) error { + return nil +} + func (m *MockConfigStore) DeleteMCPClientConfig(ctx context.Context, id string) error { return nil } @@ -704,6 +771,14 @@ func (m *MockConfigStore) GetVirtualKeyByValue(ctx context.Context, value string return nil, nil } +func (m *MockConfigStore) GetVirtualKeyMCPConfigsByMCPClientID(ctx context.Context, mcpClientID uint) ([]tables.TableVirtualKeyMCPConfig, error) { + return nil, nil +} + +func (m *MockConfigStore) GetVirtualKeyMCPConfigsByMCPClientIDs(ctx context.Context, mcpClientIDs []uint) ([]tables.TableVirtualKeyMCPConfig, error) { + return nil, nil +} + // Virtual key provider config func (m *MockConfigStore) GetVirtualKeyProviderConfigs(ctx context.Context, virtualKeyID string) ([]tables.TableVirtualKeyProviderConfig, error) { return nil, nil @@ -853,6 +928,30 @@ func (m *MockConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) return nil } +func (m *MockConfigStore) GetPricingOverrides(ctx context.Context, filter configstore.PricingOverrideFilters) ([]tables.TablePricingOverride, error) { + return []tables.TablePricingOverride{}, nil +} + +func (m *MockConfigStore) GetPricingOverridesPaginated(ctx context.Context, params configstore.PricingOverridesQueryParams) ([]tables.TablePricingOverride, int64, error) { + return []tables.TablePricingOverride{}, 0, nil +} + +func (m *MockConfigStore) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) { + return nil, configstore.ErrNotFound +} + +func (m *MockConfigStore) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { + return nil +} + +func (m *MockConfigStore) UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error { + return nil +} + +func (m *MockConfigStore) DeletePricingOverride(ctx context.Context, id string, tx ...*gorm.DB) error { + return nil +} + // Model parameters func (m *MockConfigStore) GetModelParameters(ctx context.Context, model string) (*tables.TableModelParameters, error) { return nil, nil @@ -1022,6 +1121,106 @@ func (m *MockConfigStore) DeleteOauthToken(ctx context.Context, id string) error return nil } +// Per-user OAuth session CRUD +func (m *MockConfigStore) GetOauthUserSessionByID(ctx context.Context, id string) (*tables.TableOauthUserSession, error) { + return nil, nil +} +func (m *MockConfigStore) GetOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { + return nil, nil +} +func (m *MockConfigStore) ClaimOauthUserSessionByState(ctx context.Context, state string) (*tables.TableOauthUserSession, error) { + return nil, nil +} +func (m *MockConfigStore) GetOauthUserSessionBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserSession, error) { + return nil, nil +} +func (m *MockConfigStore) CreateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { + return nil +} +func (m *MockConfigStore) UpdateOauthUserSession(ctx context.Context, session *tables.TableOauthUserSession) error { + return nil +} + +// Per-user OAuth token CRUD +func (m *MockConfigStore) GetOauthUserTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (*tables.TableOauthUserToken, error) { + return nil, nil +} +func (m *MockConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, sessionToken string) (*tables.TableOauthUserToken, error) { + return nil, nil +} +func (m *MockConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { + return nil +} +func (m *MockConfigStore) UpdateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { + return nil +} +func (m *MockConfigStore) DeleteOauthUserToken(ctx context.Context, id string) error { + return nil +} +func (m *MockConfigStore) DeleteOauthUserTokensByMCPClient(ctx context.Context, mcpClientID string) error { + return nil +} + +// Per-user OAuth Authorization Server CRUD +func (m *MockConfigStore) GetPerUserOAuthClientByClientID(ctx context.Context, clientID string) (*tables.TablePerUserOAuthClient, error) { + return nil, nil +} +func (m *MockConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *tables.TablePerUserOAuthClient) error { + return nil +} +func (m *MockConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { + return nil, nil +} +func (m *MockConfigStore) GetPerUserOAuthSessionByID(ctx context.Context, id string) (*tables.TablePerUserOAuthSession, error) { + return nil, nil +} +func (m *MockConfigStore) CreatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { + return nil +} +func (m *MockConfigStore) UpdatePerUserOAuthSession(ctx context.Context, session *tables.TablePerUserOAuthSession) error { + return nil +} +func (m *MockConfigStore) DeletePerUserOAuthSession(ctx context.Context, id string) error { + return nil +} +func (m *MockConfigStore) GetPerUserOAuthCodeByCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { + return nil, nil +} +func (m *MockConfigStore) ClaimPerUserOAuthCode(ctx context.Context, code string) (*tables.TablePerUserOAuthCode, error) { + return nil, nil +} +func (m *MockConfigStore) CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { + return nil +} +func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error { + return nil +} + +func (m *MockConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { + return nil, nil +} +func (m *MockConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + return nil +} +func (m *MockConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + return nil +} +func (m *MockConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { + return nil +} +func (m *MockConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { + return 1, nil +} +func (m *MockConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { + return nil, nil +} +func (m *MockConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { + return nil +} +func (m *MockConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { + return 1, nil +} + // Routing rules func (m *MockConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) { return nil, nil @@ -1089,6 +1288,9 @@ func (m *MockConfigStore) DeletePrompt(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { return nil, nil } +func (m *MockConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { + return nil, nil +} func (m *MockConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { return nil, nil } @@ -2163,9 +2365,8 @@ func TestGenerateKeyHash(t *testing.T) { Models: []string{"gpt-4", "gpt-3.5-turbo"}, Weight: 1.5, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://my-azure.openai.azure.com"), - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, - APIVersion: schemas.NewEnvVar(apiVersion), + Endpoint: *schemas.NewEnvVar("https://my-azure.openai.azure.com"), + APIVersion: schemas.NewEnvVar(apiVersion), }, } @@ -2186,12 +2387,30 @@ func TestGenerateKeyHash(t *testing.T) { Models: []string{"gpt-4", "gpt-3.5-turbo"}, Weight: 1.5, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Different endpoint - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, - APIVersion: schemas.NewEnvVar(apiVersion), + Endpoint: *schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Different endpoint + APIVersion: schemas.NewEnvVar(apiVersion), }, } + // Aliases alone should produce different hash + keyWithAliases := schemas.Key{ + ID: "key-1", + Name: "test-key", + Value: *schemas.NewEnvVar("sk-123"), + Models: []string{"gpt-4", "gpt-3.5-turbo"}, + Weight: 1.5, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, + } + + hashWithAliases, err := configstore.GenerateKeyHash(keyWithAliases) + if err != nil { + t.Fatalf("Failed to generate hash: %v", err) + } + + if hash1 == hashWithAliases { + t.Error("Expected different hash for keys with Aliases") + } + hash6b, err := configstore.GenerateKeyHash(key6b) if err != nil { t.Fatalf("Failed to generate hash: %v", err) @@ -4648,12 +4867,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4662,12 +4879,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4687,12 +4902,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4701,12 +4914,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://different-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4726,12 +4937,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4740,12 +4949,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4765,11 +4972,9 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4778,12 +4983,9 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-3.5-turbo": "gpt-35-turbo-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-3.5-turbo": "gpt-35-turbo-deployment", // Added! - }, }, } @@ -4816,9 +5018,6 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4842,9 +5041,6 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4873,12 +5069,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), // APIVersion is nil (will use default) - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4887,12 +5081,10 @@ func TestKeyHashComparison_AzureConfigSyncScenarios(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), // Explicitly set - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -4915,13 +5107,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -4930,13 +5120,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -4956,13 +5144,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -4971,13 +5157,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAI44QH8DHBEXAMPLE"), // Changed! SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -4997,13 +5181,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5012,13 +5194,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("differentSecretKey/NEWKEY/bPxRfiCYEXAMPLEKEY"), // Changed! Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5038,13 +5218,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5053,13 +5231,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5079,14 +5255,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-east-1:123456789012:inference-profile/old-profile"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5095,14 +5269,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-east-1:123456789012:inference-profile/new-profile"), // Changed! - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5122,13 +5294,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5137,14 +5307,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile", "claude-3.5": "claude-35-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - "claude-3.5": "claude-35-inference-profile", // Added! - }, }, } @@ -5178,9 +5345,6 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5205,9 +5369,6 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5236,14 +5397,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), // SessionToken is nil - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5252,14 +5411,12 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), SessionToken: schemas.NewEnvVar("AQoDYXdzEJr..."), // Explicitly set - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5280,13 +5437,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar(""), // Empty for IAM role auth SecretKey: *schemas.NewEnvVar(""), // Empty for IAM role auth Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5296,13 +5451,11 @@ func TestKeyHashComparison_BedrockConfigSyncScenarios(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "claude-3-inference-profile"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "claude-3-inference-profile", - }, }, } @@ -5324,12 +5477,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -5361,12 +5512,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-dashboard-edited"), // Changed via dashboard! Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -5396,12 +5545,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), // Original value from file Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -5435,13 +5582,10 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-initial"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://new-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added! - }, }, }, }, @@ -5533,8 +5677,8 @@ func TestProviderHashComparison_AzureProviderFullLifecycle(t *testing.T) { if finalConfig.Keys[0].AzureKeyConfig.APIVersion.GetValue() != "2024-10-21" { t.Errorf("Expected updated APIVersion, got %s", finalConfig.Keys[0].AzureKeyConfig.APIVersion.GetValue()) } - if len(finalConfig.Keys[0].AzureKeyConfig.Deployments) != 2 { - t.Errorf("Expected 2 deployments, got %d", len(finalConfig.Keys[0].AzureKeyConfig.Deployments)) + if len(finalConfig.Keys[0].Aliases) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(finalConfig.Keys[0].Aliases)) } t.Log("Step 5 - Final state verified, Azure provider lifecycle complete βœ“") @@ -5548,13 +5692,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), // Empty for Bedrock with IAM or AccessKey auth Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, } @@ -5585,13 +5727,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key-eu", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAI44QH8DHBEXAMPLE"), SecretKey: *schemas.NewEnvVar("je7MtGbClwBF/2Zp9Utk/h3yCo8nvbEXAMPLEKEY"), Region: schemas.NewEnvVar("eu-west-1"), // Different region - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, } @@ -5614,13 +5754,11 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -5655,15 +5793,12 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), // Added! - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! - }, }, }, }, @@ -5767,8 +5902,8 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { if fileKey.BedrockKeyConfig.ARN == nil || fileKey.BedrockKeyConfig.ARN.GetValue() != "arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile" { t.Error("Expected ARN to be set") } - if len(fileKey.BedrockKeyConfig.Deployments) != 2 { - t.Errorf("Expected 2 deployments, got %d", len(fileKey.BedrockKeyConfig.Deployments)) + if len(fileKey.Aliases) != 2 { + t.Errorf("Expected 2 deployments, got %d", len(fileKey.Aliases)) } // Verify dashboard-added key is preserved @@ -5786,15 +5921,12 @@ func TestProviderHashComparison_BedrockProviderFullLifecycle(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/my-profile"), - Deployments: map[string]string{ - "claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", - }, }, }, }, @@ -5832,12 +5964,10 @@ func TestProviderHashComparison_AzureNewProviderFromConfig(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -5901,13 +6031,11 @@ func TestProviderHashComparison_BedrockNewProviderFromConfig(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -5972,12 +6100,10 @@ func TestProviderHashComparison_AzureDBValuePreservedWhenHashMatches(t *testing. Name: "azure-openai-key", Value: *schemas.NewEnvVar("DASHBOARD-EDITED-SECRET-KEY"), // Dashboard edited this! Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), APIVersion: schemas.NewEnvVar("2024-02-01"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, }, }, @@ -6003,12 +6129,10 @@ func TestProviderHashComparison_AzureDBValuePreservedWhenHashMatches(t *testing. Name: "azure-openai-key", Value: *schemas.NewEnvVar("original-key-from-file"), // Different value than DB! Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), // Same APIVersion: schemas.NewEnvVar("2024-02-01"), // Same - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", // Same - }, }, }, }, @@ -6062,13 +6186,11 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("DASHBOARD-EDITED-ACCESS-KEY"), // Dashboard edited! SecretKey: *schemas.NewEnvVar("DASHBOARD-EDITED-SECRET-KEY"), // Dashboard edited! Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", - }, }, }, }, @@ -6094,13 +6216,11 @@ func TestProviderHashComparison_BedrockDBValuePreservedWhenHashMatches(t *testin Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "anthropic.claude-3-sonnet-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), // Different! SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), // Different! Region: schemas.NewEnvVar("us-east-1"), // Same - Deployments: map[string]string{ - "claude-3": "anthropic.claude-3-sonnet-20240229-v1:0", // Same - }, }, }, }, @@ -6185,12 +6305,10 @@ func TestProviderHashComparison_AzureConfigChangedInFile(t *testing.T) { Name: "azure-openai-key", Value: *schemas.NewEnvVar("azure-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://NEW-azure.openai.azure.com"), // Changed! APIVersion: schemas.NewEnvVar("2024-10-21"), // Changed! - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", // Added! - }, }, }, }, @@ -6275,14 +6393,12 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { Name: "aws-bedrock-key", Value: *schemas.NewEnvVar(""), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), Region: schemas.NewEnvVar("us-west-2"), // Changed! ARN: schemas.NewEnvVar("arn:aws:bedrock:us-west-2:123456789012:inference-profile/new-profile"), // Added! - Deployments: map[string]string{ - "claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0", // Added! - }, }, }, }, @@ -6337,7 +6453,6 @@ func TestProviderHashComparison_BedrockConfigChangedInFile(t *testing.T) { func TestGenerateVirtualKeyHash(t *testing.T) { // Create a virtual key teamID := "team-1" - budgetID := "budget-1" vk1 := tables.TableVirtualKey{ ID: "vk-1", Name: "test-vk", @@ -6345,7 +6460,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } // Generate hash @@ -6366,7 +6480,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } hash2, err := configstore.GenerateVirtualKeyHash(vk2) @@ -6386,7 +6499,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } hash3, err := configstore.GenerateVirtualKeyHash(vk3) @@ -6406,7 +6518,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_different", // Different value IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } hash4, err := configstore.GenerateVirtualKeyHash(vk4) @@ -6426,7 +6537,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: false, // Different IsActive TeamID: &teamID, - BudgetID: &budgetID, } hash5, err := configstore.GenerateVirtualKeyHash(vk5) @@ -6447,7 +6557,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &differentTeamID, // Different TeamID - BudgetID: &budgetID, } hash6, err := configstore.GenerateVirtualKeyHash(vk6) @@ -6467,7 +6576,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } hash7, err := configstore.GenerateVirtualKeyHash(vk7) @@ -6488,7 +6596,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, CustomerID: &customerID, // CustomerID set } @@ -6510,7 +6617,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, CustomerID: &differentCustomerID, // Different CustomerID } @@ -6523,27 +6629,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { t.Error("Expected different hash for virtual keys with different CustomerID values") } - // Different BudgetID should produce different hash - differentBudgetID := "budget-2" - vk9 := tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - Description: "Test virtual key", - Value: "vk_abc123", - IsActive: true, - TeamID: &teamID, - BudgetID: &differentBudgetID, // Different BudgetID - } - - hash9, err := configstore.GenerateVirtualKeyHash(vk9) - if err != nil { - t.Fatalf("Failed to generate hash: %v", err) - } - - if hash1 == hash9 { - t.Error("Expected different hash for virtual keys with different BudgetID") - } - // RateLimitID should produce different hash rateLimitID := "ratelimit-1" vk10 := tables.TableVirtualKey{ @@ -6553,7 +6638,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, RateLimitID: &rateLimitID, // RateLimitID set } @@ -6575,7 +6659,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, RateLimitID: &differentRateLimitID, // Different RateLimitID } @@ -6593,7 +6676,6 @@ func TestGenerateVirtualKeyHash(t *testing.T) { // TestGenerateVirtualKeyHash_WithProviderConfigs tests hash generation with provider configs func TestGenerateVirtualKeyHash_WithProviderConfigs(t *testing.T) { - budgetID := "budget-pc-1" rateLimitID := "rl-pc-1" // Virtual key with provider configs @@ -6610,7 +6692,6 @@ func TestGenerateVirtualKeyHash_WithProviderConfigs(t *testing.T) { Provider: "openai", Weight: ptrFloat64(1.0), AllowedModels: []string{"gpt-4", "gpt-3.5-turbo"}, - BudgetID: &budgetID, RateLimitID: &rateLimitID, Keys: []tables.TableKey{ {KeyID: "key-1", Name: "key-1"}, @@ -6643,7 +6724,6 @@ func TestGenerateVirtualKeyHash_WithProviderConfigs(t *testing.T) { Provider: "anthropic", // Different provider Weight: ptrFloat64(1.0), AllowedModels: []string{"claude-3"}, - BudgetID: &budgetID, RateLimitID: &rateLimitID, }, }, @@ -6672,7 +6752,6 @@ func TestGenerateVirtualKeyHash_WithProviderConfigs(t *testing.T) { Provider: "openai", Weight: ptrFloat64(2.0), // Different weight AllowedModels: []string{"gpt-4", "gpt-3.5-turbo"}, - BudgetID: &budgetID, RateLimitID: &rateLimitID, Keys: []tables.TableKey{ {KeyID: "key-1", Name: "key-1"}, @@ -6776,7 +6855,6 @@ func TestGenerateVirtualKeyHash_WithMCPConfigs(t *testing.T) { // TestVirtualKeyHashComparison_MatchingHash tests that DB config is kept when hashes match func TestVirtualKeyHashComparison_MatchingHash(t *testing.T) { teamID := "team-1" - budgetID := "budget-1" // Create a virtual key (simulating what's in config.json) fileVK := tables.TableVirtualKey{ @@ -6786,7 +6864,6 @@ func TestVirtualKeyHashComparison_MatchingHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } // Generate file hash @@ -6797,7 +6874,6 @@ func TestVirtualKeyHashComparison_MatchingHash(t *testing.T) { // Create DB virtual key with same content (simulating existing DB record) dbTeamID := "team-1" - dbBudgetID := "budget-1" dbVK := tables.TableVirtualKey{ ID: "vk-1", Name: "test-vk", @@ -6805,7 +6881,6 @@ func TestVirtualKeyHashComparison_MatchingHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &dbTeamID, - BudgetID: &dbBudgetID, ConfigHash: fileHash, // Same hash as file } @@ -6830,7 +6905,6 @@ func TestVirtualKeyHashComparison_MatchingHash(t *testing.T) { // TestVirtualKeyHashComparison_DifferentHash tests that file config is used when hashes differ func TestVirtualKeyHashComparison_DifferentHash(t *testing.T) { teamID := "team-1" - budgetID := "budget-1" // Create DB virtual key with old config dbVK := tables.TableVirtualKey{ @@ -6840,7 +6914,6 @@ func TestVirtualKeyHashComparison_DifferentHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } dbHash, err := configstore.GenerateVirtualKeyHash(dbVK) @@ -6851,7 +6924,6 @@ func TestVirtualKeyHashComparison_DifferentHash(t *testing.T) { // Create file virtual key with updated config fileTeamID := "team-1" - fileBudgetID := "budget-1" fileVK := tables.TableVirtualKey{ ID: "vk-1", Name: "new-name", // Updated name @@ -6859,7 +6931,6 @@ func TestVirtualKeyHashComparison_DifferentHash(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &fileTeamID, - BudgetID: &fileBudgetID, } fileHash, err := configstore.GenerateVirtualKeyHash(fileVK) @@ -7031,26 +7102,6 @@ func TestVirtualKeyHashComparison_OptionalFieldsPresence(t *testing.T) { t.Error("Expected different hash for team_id vs customer_id") } - // Virtual key with budget_id - budgetID := "budget-1" - vkWithBudget := tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - Description: "", - Value: "vk_abc123", - IsActive: true, - BudgetID: &budgetID, - } - - hashWithBudget, err := configstore.GenerateVirtualKeyHash(vkWithBudget) - if err != nil { - t.Fatalf("Failed to generate hash: %v", err) - } - - if hashNoOptional == hashWithBudget { - t.Error("Expected different hash when budget_id is added") - } - // Virtual key with rate_limit_id rateLimitID := "rl-1" vkWithRateLimit := tables.TableVirtualKey{ @@ -7077,7 +7128,6 @@ func TestVirtualKeyHashComparison_OptionalFieldsPresence(t *testing.T) { // TestVirtualKeyHashComparison_FieldValueChanges tests hash changes when field values change func TestVirtualKeyHashComparison_FieldValueChanges(t *testing.T) { teamID := "team-1" - budgetID := "budget-1" // Base virtual key baseVK := tables.TableVirtualKey{ @@ -7087,7 +7137,6 @@ func TestVirtualKeyHashComparison_FieldValueChanges(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, } baseHash, err := configstore.GenerateVirtualKeyHash(baseVK) @@ -7135,27 +7184,12 @@ func TestVirtualKeyHashComparison_FieldValueChanges(t *testing.T) { t.Error("Expected different hash when TeamID value changes") } - // Change BudgetID value - newBudgetID := "budget-2" - vkChangedBudget := baseVK - vkChangedBudget.BudgetID = &newBudgetID - - hashChangedBudget, err := configstore.GenerateVirtualKeyHash(vkChangedBudget) - if err != nil { - t.Fatalf("Failed to generate hash: %v", err) - } - - if baseHash == hashChangedBudget { - t.Error("Expected different hash when BudgetID value changes") - } - t.Log("βœ“ Field value changes correctly detected in hash") } // TestVirtualKeyHashComparison_RoundTrip tests JSON β†’ DB β†’ same JSON produces no changes func TestVirtualKeyHashComparison_RoundTrip(t *testing.T) { teamID := "team-1" - budgetID := "budget-1" rateLimitID := "rl-1" // Original config.json virtual key @@ -7166,7 +7200,6 @@ func TestVirtualKeyHashComparison_RoundTrip(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &teamID, - BudgetID: &budgetID, RateLimitID: &rateLimitID, ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ { @@ -7189,7 +7222,6 @@ func TestVirtualKeyHashComparison_RoundTrip(t *testing.T) { // Same config.json on reload (simulating app restart) reloadTeamID := "team-1" - reloadBudgetID := "budget-1" reloadRateLimitID := "rl-1" reloadVK := tables.TableVirtualKey{ ID: "vk-1", @@ -7198,7 +7230,6 @@ func TestVirtualKeyHashComparison_RoundTrip(t *testing.T) { Value: "vk_abc123", IsActive: true, TeamID: &reloadTeamID, - BudgetID: &reloadBudgetID, RateLimitID: &reloadRateLimitID, ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ { @@ -8710,7 +8741,6 @@ func TestSQLite_FullLifecycle_InitialLoad(t *testing.T) { Description: "Test virtual key 1", Value: "vk_test123", IsActive: true, - BudgetID: &budgetID, RateLimitID: &rateLimitID, ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ { @@ -10425,9 +10455,6 @@ func TestGenerateKeyHash_StableOrdering(t *testing.T) { // TestGenerateVirtualKeyHash_StableProviderConfigOrdering verifies hash stability with different provider config orderings func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { - budgetID1 := "budget-1" - budgetID2 := "budget-2" - // VK with provider configs in order A vkOrderA := tables.TableVirtualKey{ ID: "vk-1", @@ -10442,7 +10469,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "openai", Weight: ptrFloat64(1.0), AllowedModels: []string{"gpt-4"}, - BudgetID: &budgetID1, }, { ID: 2, @@ -10450,7 +10476,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "anthropic", Weight: ptrFloat64(2.0), AllowedModels: []string{"claude-3"}, - BudgetID: &budgetID2, }, { ID: 3, @@ -10483,7 +10508,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "anthropic", Weight: ptrFloat64(2.0), AllowedModels: []string{"claude-3"}, - BudgetID: &budgetID2, }, { ID: 1, @@ -10491,7 +10515,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "openai", Weight: ptrFloat64(1.0), AllowedModels: []string{"gpt-4"}, - BudgetID: &budgetID1, }, }, } @@ -10510,7 +10533,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "anthropic", Weight: ptrFloat64(2.0), AllowedModels: []string{"claude-3"}, - BudgetID: &budgetID2, }, { ID: 1, @@ -10518,7 +10540,6 @@ func TestGenerateVirtualKeyHash_StableProviderConfigOrdering(t *testing.T) { Provider: "openai", Weight: ptrFloat64(1.0), AllowedModels: []string{"gpt-4"}, - BudgetID: &budgetID1, }, { ID: 3, @@ -10931,8 +10952,6 @@ func TestGenerateVirtualKeyHash_StableToolsToExecuteOrdering(t *testing.T) { // TestGenerateVirtualKeyHash_StableCombinedOrdering verifies hash stability with all nested orderings randomized func TestGenerateVirtualKeyHash_StableCombinedOrdering(t *testing.T) { - budgetID := "budget-1" - // VK with all elements in order A vkOrderA := tables.TableVirtualKey{ ID: "vk-1", @@ -10940,7 +10959,6 @@ func TestGenerateVirtualKeyHash_StableCombinedOrdering(t *testing.T) { Description: "Test virtual key", Value: "vk_abc123", IsActive: true, - BudgetID: &budgetID, ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ { ID: 1, @@ -10983,7 +11001,6 @@ func TestGenerateVirtualKeyHash_StableCombinedOrdering(t *testing.T) { Description: "Test virtual key", Value: "vk_abc123", IsActive: true, - BudgetID: &budgetID, ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ { ID: 2, @@ -12581,6 +12598,91 @@ func TestSQLite_Governance_DBOnly_AllPreserved(t *testing.T) { t.Log("βœ“ All dashboard-added entities preserved on reload") } +// TestSQLite_Governance_PricingOverrides_Reconciliation tests that pricing overrides +// defined in config.json are properly reconciled on reload (create, update, preserve). +func TestSQLite_Governance_PricingOverrides_Reconciliation(t *testing.T) { + initTestLogger() + tempDir := createTempDir(t) + + configData := makeConfigDataWithProvidersAndDir(nil, tempDir) + configData.Governance = &configstore.GovernanceConfig{ + PricingOverrides: []tables.TablePricingOverride{ + { + ID: "po-1", + Name: "Override One", + ScopeKind: "global", + MatchType: "exact", + Pattern: "gpt-4", + RequestTypes: []schemas.RequestType{ + schemas.ChatCompletionRequest, + }, + }, + }, + } + createConfigFile(t, tempDir, configData) + + ctx := context.Background() + + // First load: pricing override should be created in the DB + config1, err := LoadConfig(ctx, tempDir) + if err != nil { + t.Fatalf("First LoadConfig failed: %v", err) + } + + gov1, err := config1.ConfigStore.GetGovernanceConfig(ctx) + if err != nil { + t.Fatalf("Failed to get governance config after first load: %v", err) + } + if len(gov1.PricingOverrides) != 1 { + t.Fatalf("Expected 1 pricing override after first load, got %d", len(gov1.PricingOverrides)) + } + if gov1.PricingOverrides[0].ID != "po-1" { + t.Errorf("Expected pricing override ID 'po-1', got '%s'", gov1.PricingOverrides[0].ID) + } + if gov1.PricingOverrides[0].ConfigHash == "" { + t.Error("Pricing override hash not set after first load") + } + config1.Close(ctx) + + // Second load (unchanged config): should NOT fail with duplicate key error + config2, err := LoadConfig(ctx, tempDir) + if err != nil { + t.Fatalf("Second LoadConfig failed (duplicate key bug): %v", err) + } + + gov2, err := config2.ConfigStore.GetGovernanceConfig(ctx) + if err != nil { + t.Fatalf("Failed to get governance config after second load: %v", err) + } + if len(gov2.PricingOverrides) != 1 { + t.Fatalf("Expected 1 pricing override after second load, got %d", len(gov2.PricingOverrides)) + } + config2.Close(ctx) + + // Third load (updated config): should update the existing override, not create a duplicate + configData.Governance.PricingOverrides[0].Pattern = "gpt-4o" + createConfigFile(t, tempDir, configData) + + config3, err := LoadConfig(ctx, tempDir) + if err != nil { + t.Fatalf("Third LoadConfig failed: %v", err) + } + defer config3.Close(ctx) + + gov3, err := config3.ConfigStore.GetGovernanceConfig(ctx) + if err != nil { + t.Fatalf("Failed to get governance config after third load: %v", err) + } + if len(gov3.PricingOverrides) != 1 { + t.Fatalf("Expected 1 pricing override after update, got %d", len(gov3.PricingOverrides)) + } + if gov3.PricingOverrides[0].Pattern != "gpt-4o" { + t.Errorf("Pricing override pattern not updated: got '%s', want 'gpt-4o'", gov3.PricingOverrides[0].Pattern) + } + + t.Log("βœ“ Pricing overrides reconciliation works correctly (create, idempotent reload, update)") +} + // =================================================================================== // RUNTIME VS MIGRATION HASH PARITY TESTS (SQLite Integration) // =================================================================================== @@ -13253,9 +13355,8 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { t.Run("AzureKeyConfig_GORMRoundTrip", func(t *testing.T) { apiVersion := "2024-02-01" azureConfig := &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://myresource.openai.azure.com"), - APIVersion: schemas.NewEnvVar(apiVersion), - Deployments: map[string]string{"gpt-4": "gpt-4-deployment"}, + Endpoint: *schemas.NewEnvVar("https://myresource.openai.azure.com"), + APIVersion: schemas.NewEnvVar(apiVersion), } keyToSave := tables.TableKey{ @@ -13266,6 +13367,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: *schemas.NewEnvVar("azure-key-value"), Weight: ptrFloat64(1.0), AzureKeyConfig: azureConfig, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, } schemaKey := schemas.Key{ @@ -13273,6 +13375,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: keyToSave.Value, Weight: getWeight(keyToSave.Weight), AzureKeyConfig: keyToSave.AzureKeyConfig, + Aliases: keyToSave.Aliases, } hashBeforeSave, _ := configstore.GenerateKeyHash(schemaKey) @@ -13286,6 +13389,7 @@ func TestGenerateKeyHash_RuntimeVsMigrationParity(t *testing.T) { Value: keyFromDB.Value, Weight: getWeight(keyFromDB.Weight), AzureKeyConfig: keyFromDB.AzureKeyConfig, + Aliases: keyFromDB.Aliases, } hashAfterLoad, _ := configstore.GenerateKeyHash(schemaKeyFromDB) @@ -14069,11 +14173,9 @@ func TestSQLite_Key_UseForBatchAPIChange_Detected(t *testing.T) { } } -// TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit verifies that BudgetID and RateLimitID -// in VK provider configs affect hash generation. -func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { - budgetID1 := "budget-1" - budgetID2 := "budget-2" +// TestGenerateVirtualKeyHash_ProviderConfigRateLimit verifies that RateLimitID +// in VK provider configs affects hash generation. +func TestGenerateVirtualKeyHash_ProviderConfigRateLimit(t *testing.T) { rateLimitID1 := "rate-limit-1" rateLimitID2 := "rate-limit-2" weight := 1.0 @@ -14084,34 +14186,6 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { vk2 tables.TableVirtualKey expectEqual bool }{ - { - name: "different_budget_id_different_hash", - vk1: tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - IsActive: true, - ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ - { - Provider: "openai", - Weight: &weight, - BudgetID: &budgetID1, - }, - }, - }, - vk2: tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - IsActive: true, - ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ - { - Provider: "openai", - Weight: &weight, - BudgetID: &budgetID2, - }, - }, - }, - expectEqual: false, - }, { name: "different_rate_limit_id_different_hash", vk1: tables.TableVirtualKey{ @@ -14140,34 +14214,6 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { }, expectEqual: false, }, - { - name: "nil_vs_set_budget_id_different_hash", - vk1: tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - IsActive: true, - ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ - { - Provider: "openai", - Weight: &weight, - BudgetID: nil, - }, - }, - }, - vk2: tables.TableVirtualKey{ - ID: "vk-1", - Name: "test-vk", - IsActive: true, - ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ - { - Provider: "openai", - Weight: &weight, - BudgetID: &budgetID1, - }, - }, - }, - expectEqual: false, - }, { name: "nil_vs_set_rate_limit_id_different_hash", vk1: tables.TableVirtualKey{ @@ -14197,7 +14243,7 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { expectEqual: false, }, { - name: "same_budget_and_rate_limit_same_hash", + name: "same_rate_limit_same_hash", vk1: tables.TableVirtualKey{ ID: "vk-1", Name: "test-vk", @@ -14206,7 +14252,6 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { { Provider: "openai", Weight: &weight, - BudgetID: &budgetID1, RateLimitID: &rateLimitID1, }, }, @@ -14219,7 +14264,6 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { { Provider: "openai", Weight: &weight, - BudgetID: &budgetID1, RateLimitID: &rateLimitID1, }, }, @@ -14250,107 +14294,6 @@ func TestGenerateVirtualKeyHash_ProviderConfigBudgetRateLimit(t *testing.T) { } } -// TestSQLite_VKProviderConfig_BudgetAndRateLimit verifies that BudgetID and RateLimitID -// in VK provider configs are properly persisted and retrieved from SQLite. -func TestSQLite_VKProviderConfig_BudgetAndRateLimit(t *testing.T) { - initTestLogger() - tempDir := createTempDir(t) - - budgetID := "budget-123" - rateLimitID := "rate-limit-456" - vkID := uuid.NewString() - weight := 1.0 - - // Create config with VK that has provider config with BudgetID and RateLimitID - configData := makeConfigDataFullWithDir( - nil, - map[string]configstore.ProviderConfig{ - "openai": { - Keys: []schemas.Key{ - { - ID: uuid.NewString(), - Name: "openai-key", - Value: *schemas.NewEnvVar("sk-test"), - Weight: 1, - }, - }, - }, - }, - &configstore.GovernanceConfig{ - Budgets: []tables.TableBudget{ - { - ID: budgetID, - MaxLimit: 100.0, - }, - }, - RateLimits: []tables.TableRateLimit{ - { - ID: rateLimitID, - RequestMaxLimit: int64Ptr(60), - TokenMaxLimit: int64Ptr(10000), - }, - }, - VirtualKeys: []tables.TableVirtualKey{ - { - ID: vkID, - Name: "test-vk", - Value: "vk-test-value", - IsActive: true, - ProviderConfigs: []tables.TableVirtualKeyProviderConfig{ - { - Provider: "openai", - Weight: &weight, - BudgetID: &budgetID, - RateLimitID: &rateLimitID, - }, - }, - }, - }, - }, - tempDir, - ) - - // Load config - createConfigFile(t, tempDir, configData) - config, err := LoadConfig(context.Background(), tempDir) - if err != nil { - t.Fatalf("LoadConfig failed: %v", err) - } - defer config.Close(context.Background()) - - // Verify the governance config has the VK with provider configs - if config.GovernanceConfig == nil { - t.Fatal("Expected GovernanceConfig to exist") - } - if len(config.GovernanceConfig.VirtualKeys) == 0 { - t.Fatal("Expected VirtualKeys in GovernanceConfig") - } - - // Find the VK and verify provider config - var foundVK *tables.TableVirtualKey - for i := range config.GovernanceConfig.VirtualKeys { - if config.GovernanceConfig.VirtualKeys[i].ID == vkID { - foundVK = &config.GovernanceConfig.VirtualKeys[i] - break - } - } - if foundVK == nil { - t.Fatalf("Virtual key %s not found in config", vkID) - } - - if len(foundVK.ProviderConfigs) == 0 { - t.Fatal("Expected VK to have provider configs") - } - - pc := foundVK.ProviderConfigs[0] - if pc.BudgetID == nil || *pc.BudgetID != budgetID { - t.Errorf("Expected BudgetID=%s, got %v", budgetID, pc.BudgetID) - } - if pc.RateLimitID == nil || *pc.RateLimitID != rateLimitID { - t.Errorf("Expected RateLimitID=%s, got %v", rateLimitID, pc.RateLimitID) - } -} - // intPtr is a helper to create a pointer to an int func intPtr(i int) *int { return &i @@ -14370,14 +14313,12 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), ProjectNumber: *schemas.NewEnvVar("123456789"), Region: *schemas.NewEnvVar("us-central1"), AuthCredentials: *schemas.NewEnvVar(`{"type":"service_account"}`), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14386,14 +14327,12 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), ProjectNumber: *schemas.NewEnvVar("123456789"), Region: *schemas.NewEnvVar("us-central1"), AuthCredentials: *schemas.NewEnvVar(`{"type":"service_account"}`), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14518,12 +14457,10 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -14532,13 +14469,10 @@ func TestKeyHashComparison_VertexConfigSyncScenarios(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-api-key-123"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project-123"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", // Added! - }, }, } @@ -14922,11 +14856,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -14935,12 +14867,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", // Added - }, }, } @@ -14958,12 +14887,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment", "gpt-4o": "gpt-4o-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - "gpt-4o": "gpt-4o-deployment", - }, }, } @@ -14972,11 +14898,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", // gpt-4o removed - }, }, } @@ -14994,11 +14918,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment-v1"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment-v1", - }, }, } @@ -15007,11 +14929,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment-v2"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment-v2", // Value changed - }, }, } @@ -15030,8 +14950,7 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, AzureKeyConfig: &schemas.AzureKeyConfig{ - Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: nil, // No deployments + Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), }, } @@ -15040,11 +14959,9 @@ func TestKeyHashComparison_AzureDeploymentsChange(t *testing.T) { Name: "azure-key", Value: *schemas.NewEnvVar("azure-api-key"), Weight: 1, + Aliases: schemas.KeyAliases{"gpt-4": "gpt-4-deployment"}, AzureKeyConfig: &schemas.AzureKeyConfig{ Endpoint: *schemas.NewEnvVar("https://myazure.openai.azure.com"), - Deployments: map[string]string{ - "gpt-4": "gpt-4-deployment", - }, }, } @@ -15065,13 +14982,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - }, }, } @@ -15080,14 +14995,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5", // Added - }, }, } @@ -15105,14 +15017,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", - "claude-3.5": "arn:aws:bedrock:us-east-1::inference-profile/claude-3.5", - }, }, } @@ -15121,13 +15030,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3", // claude-3.5 removed - }, }, } @@ -15145,13 +15052,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-old"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-old", - }, }, } @@ -15160,13 +15065,11 @@ func TestKeyHashComparison_BedrockDeploymentsChange(t *testing.T) { Name: "bedrock-key", Value: *schemas.NewEnvVar("bedrock-key"), Weight: 1, + Aliases: schemas.KeyAliases{"claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-new"}, BedrockKeyConfig: &schemas.BedrockKeyConfig{ AccessKey: *schemas.NewEnvVar("AKIAIOSFODNN7EXAMPLE"), SecretKey: *schemas.NewEnvVar("wJalrXUtnFEMI"), Region: schemas.NewEnvVar("us-east-1"), - Deployments: map[string]string{ - "claude-3": "arn:aws:bedrock:us-east-1::inference-profile/claude-3-new", // Value changed - }, }, } @@ -15187,12 +15090,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -15201,13 +15102,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", // Added - }, }, } @@ -15225,13 +15123,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint", "gemini-1.5-pro": "gemini-15-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - "gemini-1.5-pro": "gemini-15-pro-endpoint", - }, }, } @@ -15240,12 +15135,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", // gemini-1.5-pro removed - }, }, } @@ -15263,12 +15156,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint-v1"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint-v1", - }, }, } @@ -15277,12 +15168,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint-v2"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint-v2", // Value changed - }, }, } @@ -15301,9 +15190,8 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, VertexKeyConfig: &schemas.VertexKeyConfig{ - ProjectID: *schemas.NewEnvVar("my-project"), - Region: *schemas.NewEnvVar("us-central1"), - Deployments: nil, // No deployments + ProjectID: *schemas.NewEnvVar("my-project"), + Region: *schemas.NewEnvVar("us-central1"), }, } @@ -15312,12 +15200,10 @@ func TestKeyHashComparison_VertexDeploymentsChange(t *testing.T) { Name: "vertex-key", Value: *schemas.NewEnvVar("vertex-creds"), Weight: 1, + Aliases: schemas.KeyAliases{"gemini-pro": "gemini-pro-endpoint"}, VertexKeyConfig: &schemas.VertexKeyConfig{ ProjectID: *schemas.NewEnvVar("my-project"), Region: *schemas.NewEnvVar("us-central1"), - Deployments: map[string]string{ - "gemini-pro": "gemini-pro-endpoint", - }, }, } @@ -15407,12 +15293,13 @@ var enterpriseSchemaPaths = map[string]bool{ var excludedGoFields = map[string]map[string]bool{ // ClientConfig - MCP fields are managed at MCP level, not client level "configstore.ClientConfig": { - "ConfigHash": true, - "allowed_headers": true, // Internal use - "mcp_agent_depth": true, // Managed via MCP config - "mcp_code_mode_binding_level": true, - "mcp_tool_execution_timeout": true, - "mcp_tool_sync_interval": true, + "ConfigHash": true, + "allowed_headers": true, // Internal use + "mcp_agent_depth": true, // Managed via MCP config + "mcp_code_mode_binding_level": true, + "mcp_tool_execution_timeout": true, + "mcp_tool_sync_interval": true, + "mcp_disable_auto_tool_inject": true, }, "configstore.ProviderConfig": {"ConfigHash": true}, // GovernanceConfig - some fields are internal/enterprise @@ -15423,9 +15310,11 @@ var excludedGoFields = map[string]map[string]bool{ }, // Table types have DB-specific fields "tables.TableBudget": { - "config_hash": true, - "created_at": true, - "updated_at": true, + "config_hash": true, + "created_at": true, + "updated_at": true, + "virtual_key_id": true, // Internal DB FK for multi-budget ownership + "provider_config_id": true, // Internal DB FK for multi-budget ownership }, "tables.TableRateLimit": { "config_hash": true, @@ -15454,14 +15343,16 @@ var excludedGoFields = map[string]map[string]bool{ "config_hash": true, "created_at": true, "updated_at": true, - "budget": true, // GORM relation + "budgets": true, // GORM relation (budgets have virtual_key_id FK) "rate_limit": true, // GORM relation "team": true, // GORM relation "customer": true, // GORM relation }, "tables.TableVirtualKeyProviderConfig": { - "budget": true, // GORM relation - "rate_limit": true, // GORM relation + "rate_limit": true, // GORM relation + "allow_all_keys": true, // Internal DB field; users configure via key_ids + "keys": true, // GORM many2many relation; users configure via key_ids + "budgets": true, // GORM relation (budgets have provider_config_id FK) }, "tables.TableVirtualKeyMCPConfig": { "mcp_client": true, // GORM relation @@ -15505,7 +15396,8 @@ var excludedSchemaFields = map[string]map[string]bool{ "allowed_headers": true, // Not in ClientConfig }, "governance.virtual_keys.provider_configs": { - "keys": true, // Complex nested type, validated separately + "keys": true, // Complex nested type, validated separately + "key_ids": true, // Config-file format; handled via custom UnmarshalJSON into allow_all_keys/keys }, "mcp.client_configs": { "websocket_config": true, // Schema documents all connection types diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index ea0c8b0a72..36c56cc2e5 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -8,6 +8,7 @@ package lib import ( "context" + "fmt" "strconv" "strings" "time" @@ -31,6 +32,37 @@ const ( FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode" ) +// ParseSessionIDFromBaggage extracts the session-id baggage member value. +// It supports simple W3C baggage parsing sufficient for log grouping. +func ParseSessionIDFromBaggage(header string) string { + for _, member := range strings.Split(header, ",") { + member = strings.TrimSpace(member) + if member == "" { + continue + } + + parts := strings.SplitN(member, ";", 2) + kv := strings.SplitN(strings.TrimSpace(parts[0]), "=", 2) + if len(kv) != 2 { + continue + } + + key := strings.ToLower(strings.TrimSpace(kv[0])) + value := strings.TrimSpace(kv[1]) + if key != "session-id" || value == "" { + continue + } + if len(value) > 255 { + if logger != nil { + logger.Warn("session-id exceeds 255 chars, ignoring: length=%d, prefix=%s", len(value), value[:255]) + } + continue + } + return value + } + return "" +} + // ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context, // preserving important header values for monitoring and tracing purposes. // @@ -92,7 +124,7 @@ const ( // // Maxim tracing data, MCP filters, governance keys, API keys, cache settings, // // session stickiness, and extra headers -func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher) (*schemas.BifrostContext, context.CancelFunc) { +func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, matcher *HeaderMatcher, mcpHeaderCombinedAllowlist schemas.WhiteList) (*schemas.BifrostContext, context.CancelFunc) { // Reuse a shared request-scoped context when available. var bifrostCtx *schemas.BifrostContext var cancel context.CancelFunc @@ -141,6 +173,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat maximTags := make(map[string]string) // Initialize extra headers map for headers prefixed with x-bf-eh- extraHeaders := make(map[string][]string) + // Initialize extra headers map for headers in the mcp header combined allowlist + mcpExtraHeaders := make(map[string][]string) // Security denylist of header names that should never be accepted (case-insensitive) // This denylist is always enforced regardless of user configuration securityDenylist := map[string]bool{ @@ -152,8 +186,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat "transfer-encoding": true, // prevent auth/key overrides via x-bf-eh-* - "x-api-key": true, - "x-goog-api-key": true, + "x-api-key": true, + "x-goog-api-key": true, "x-bf-api-key": true, "x-bf-api-key-id": true, "x-bf-vk": true, @@ -171,6 +205,12 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat // Then process other headers ctx.Request.Header.All()(func(key, value []byte) bool { keyStr := strings.ToLower(string(key)) + if keyStr == "baggage" { + if sessionID := ParseSessionIDFromBaggage(string(value)); sessionID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID) + } + return true + } if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok { bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) return true @@ -377,6 +417,11 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat return true } } + // Handle MCP extra headers + if mcpHeaderCombinedAllowlist.IsAllowed(keyStr) { + mcpExtraHeaders[keyStr] = append(mcpExtraHeaders[keyStr], string(value)) + return true + } // Send back raw response header if keyStr == "x-bf-send-back-raw-response" { if valueStr := string(value); valueStr == "true" { @@ -411,6 +456,11 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat bifrostCtx.SetValue(schemas.BifrostContextKeyExtraHeaders, extraHeaders) } + // Store collected MCP extra headers in the context if any were found + if len(mcpExtraHeaders) > 0 { + bifrostCtx.SetValue(schemas.BifrostContextKeyMCPExtraHeaders, mcpExtraHeaders) + } + // Collect all request headers for downstream use (e.g., governance required headers check) // Keys are lowercased for case-insensitive lookup allHeaders := make(map[string]string) @@ -420,6 +470,21 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat }) bifrostCtx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders) + // Extract per-user MCP OAuth user identifier from X-Bf-User-Id header + if mcpUserID := string(ctx.Request.Header.Peek("X-Bf-User-Id")); mcpUserID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserID, mcpUserID) + } + + // Build and set OAuth redirect URI for per-user OAuth flows + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" + } + host := string(ctx.Host()) + if host != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyOAuthRedirectURI, fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host)) + } + if allowDirectKeys { // Extract API key from Authorization header (Bearer format), x-api-key, or x-goog-api-key header var apiKey string @@ -458,8 +523,8 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat key := schemas.Key{ ID: "header-provided", // Identifier for header-provided keys Value: *schemas.NewEnvVar(apiKey), - Models: []string{}, // Empty models list - will be validated by provider - Weight: 1.0, // Default weight + Models: schemas.WhiteList{"*"}, // Allow all models + Weight: 1.0, // Default weight } bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, key) } diff --git a/transports/bifrost-http/lib/ctx_test.go b/transports/bifrost-http/lib/ctx_test.go index 3f522a3548..abc0620883 100644 --- a/transports/bifrost-http/lib/ctx_test.go +++ b/transports/bifrost-http/lib/ctx_test.go @@ -10,13 +10,36 @@ import ( "github.com/valyala/fasthttp" ) +func TestParseSessionIDFromBaggage(t *testing.T) { + tests := []struct { + name string + header string + want string + }{ + {name: "single member", header: "session-id=abc", want: "abc"}, + {name: "multiple members", header: "foo=bar, session-id=abc, baz=qux", want: "abc"}, + {name: "member with properties", header: "session-id=abc;ttl=60", want: "abc"}, + {name: "spaces preserved around parsing", header: " foo=bar , session-id = abc123 ;ttl=60 ", want: "abc123"}, + {name: "missing member", header: "foo=bar", want: ""}, + {name: "malformed ignored", header: "session-id, foo=bar", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParseSessionIDFromBaggage(tt.header); got != tt.want { + t.Fatalf("ParseSessionIDFromBaggage(%q) = %q, want %q", tt.header, got, tt.want) + } + }) + } +} + func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} base := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) base.SetValue(schemas.BifrostContextKeyRequestID, "req-shared") ctx.SetUserValue(FastHTTPUserValueBifrostContext, base) - converted, cancel := ConvertToBifrostContext(ctx, false, nil) + converted, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancel() if converted == nil { @@ -36,13 +59,13 @@ func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { func TestConvertToBifrostContext_SecondCallReturnsSameSharedContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} - first, cancelFirst := ConvertToBifrostContext(ctx, false, nil) + first, cancelFirst := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancelFirst() if first == nil { t.Fatal("expected first context to be non-nil") } - second, cancelSecond := ConvertToBifrostContext(ctx, false, nil) + second, cancelSecond := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancelSecond() if second == nil { t.Fatal("expected second context to be non-nil") @@ -69,7 +92,7 @@ func TestConvertToBifrostContext_StarAllowlistSecurityHeadersBlocked(t *testing. ctx.Request.Header.Set("x-bf-eh-connection", "should-be-blocked") ctx.Request.Header.Set("x-bf-eh-proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -103,7 +126,7 @@ func TestConvertToBifrostContext_StarAllowlistDirectForwardingSecurityBlocked(t // Security headers sent directly β€” should be blocked ctx.Request.Header.Set("proxy-authorization", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -140,7 +163,7 @@ func TestConvertToBifrostContext_PrefixWildcardDirectForwarding(t *testing.T) { // Header not matching the pattern ctx.Request.Header.Set("openai-version", "should-not-forward") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -168,7 +191,7 @@ func TestConvertToBifrostContext_WildcardAllowlistFiltering(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-anthropic-version", "2024-01-01") ctx.Request.Header.Set("x-bf-eh-openai-version", "should-be-blocked") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -196,7 +219,7 @@ func TestConvertToBifrostContext_WildcardDenylistBlocking(t *testing.T) { ctx.Request.Header.Set("x-bf-eh-x-internal-secret", "blocked-value") ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, matcher, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -217,7 +240,7 @@ func TestConvertToBifrostContext_NilMatcher(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("x-bf-eh-custom-header", "allowed-value") - bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil) + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) defer cancel() extraHeaders, _ := bifrostCtx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) @@ -226,3 +249,27 @@ func TestConvertToBifrostContext_NilMatcher(t *testing.T) { t.Error("expected custom-header to be forwarded with nil matcher") } } + +func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux") + + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + defer cancel() + + if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" { + t.Fatalf("parent request id = %q, want %q", got, "rt-123") + } +} + +func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.Set("baggage", "session-id= ") + + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + defer cancel() + + if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil { + t.Fatalf("parent request id should be unset, got %#v", got) + } +} diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 20894ee53a..3cdf2f31fa 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -11,6 +11,7 @@ import ( "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "github.com/maximhq/bifrost/transports/bifrost-http/handlers" @@ -62,6 +63,9 @@ func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifro } return telemetry.Init(telConfig, bifrostConfig.ModelCatalog, logger) + case prompts.PluginName: + return prompts.Init(ctx, bifrostConfig.ConfigStore, logger) + case logging.PluginName: loggingConfig, err := MarshalPluginConfig[logging.Config](pluginConfig) if err != nil { @@ -159,7 +163,15 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(telemetry.PluginName, builtinPlacement, schemas.Ptr(1)) - // 2. Logging (if enabled) + // 2. Prompts (requires config store for prompt repository) + if s.Config.ConfigStore != nil { + s.registerPluginWithStatus(ctx, prompts.PluginName, nil, nil, false) + } else { + s.markPluginDisabled(prompts.PluginName) + } + s.Config.SetPluginOrderInfo(prompts.PluginName, builtinPlacement, schemas.Ptr(2)) + + // 3. Logging (if enabled) if (s.Config.ClientConfig.EnableLogging == nil || *s.Config.ClientConfig.EnableLogging) && s.Config.LogsStore != nil { config := &logging.Config{ DisableContentLogging: &s.Config.ClientConfig.DisableContentLogging, @@ -169,60 +181,61 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } else { s.markPluginDisabled(logging.PluginName) } - s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(2)) + s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(3)) - // 3. Governance (if enabled and not enterprise) + // 4. Governance (if enabled and not enterprise) if ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil { config := &governance.Config{ - IsVkMandatory: &s.Config.ClientConfig.EnforceAuthOnInference, - RequiredHeaders: &s.Config.ClientConfig.RequiredHeaders, + IsVkMandatory: &s.Config.ClientConfig.EnforceAuthOnInference, + RequiredHeaders: &s.Config.ClientConfig.RequiredHeaders, + DisableAutoToolInject: &s.Config.ClientConfig.MCPDisableAutoToolInject, + RoutingChainMaxDepth: &s.Config.ClientConfig.RoutingChainMaxDepth, } s.registerPluginWithStatus(ctx, governance.PluginName, nil, config, false) } else { s.markPluginDisabled(governance.PluginName) } - s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(3)) + s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(4)) - // 4. OTEL (if configured in PluginConfigs) + // 5. OTEL (if configured in PluginConfigs) otelConfig := s.getPluginConfig(otel.PluginName) if otelConfig != nil && otelConfig.Enabled { s.registerPluginWithStatus(ctx, otel.PluginName, nil, otelConfig.Config, false) } else { s.markPluginDisabled(otel.PluginName) } - s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(4)) + s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(5)) - // 5. Semantic Cache (if configured in PluginConfigs) + // 6. Semantic Cache (if configured in PluginConfigs) semanticCacheConfig := s.getPluginConfig(semanticcache.PluginName) if semanticCacheConfig != nil && semanticCacheConfig.Enabled { s.registerPluginWithStatus(ctx, semanticcache.PluginName, nil, semanticCacheConfig.Config, false) } else { s.markPluginDisabled(semanticcache.PluginName) } - s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(5)) + s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 6. Litellmcompat (if configured in PluginConfigs) + // 7. Litellmcompat (if configured in PluginConfigs) litellmcompatConfig := s.getPluginConfig(litellmcompat.PluginName) if litellmcompatConfig != nil && litellmcompatConfig.Enabled { s.registerPluginWithStatus(ctx, litellmcompat.PluginName, nil, litellmcompatConfig.Config, false) } else { s.markPluginDisabled(litellmcompat.PluginName) } - s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(6)) + s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(7)) - // 7. Maxim (if configured in PluginConfigs) + // 8. Maxim (if configured in PluginConfigs) maximConfig := s.getPluginConfig(maxim.PluginName) if maximConfig != nil && maximConfig.Enabled { s.registerPluginWithStatus(ctx, maxim.PluginName, nil, maximConfig.Config, false) } else { s.markPluginDisabled(maxim.PluginName) } - s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(7)) + s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(8)) return nil } - // loadCustomPlugins loads plugins from PluginConfigs func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { for _, cfg := range s.Config.PluginConfigs { diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 757d712510..46178d6680 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -9,8 +9,8 @@ import ( "net" "os" "os/signal" - "slices" "strings" + "sync" "syscall" "time" @@ -26,6 +26,7 @@ import ( "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "github.com/maximhq/bifrost/transports/bifrost-http/handlers" @@ -62,6 +63,8 @@ type ServerCallbacks interface { // Pricing related callbacks ReloadPricingManager(ctx context.Context) error ForceReloadPricing(ctx context.Context) error + UpsertPricingOverride(ctx context.Context, override *tables.TablePricingOverride) error + DeletePricingOverride(ctx context.Context, id string) error // Proxy related callbacks ReloadProxyConfig(ctx context.Context, config *tables.GlobalProxyConfig) error // Client config related callbacks @@ -89,7 +92,11 @@ type ServerCallbacks interface { AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error RemoveMCPClient(ctx context.Context, id string) error UpdateMCPClient(ctx context.Context, id string, updatedConfig *schemas.MCPClientConfig) error - UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error + UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error + // VerifyPerUserOAuthConnection verifies an MCP server using a temporary token and discovers tools. + VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) + // SetClientTools updates the tool map for an existing client. + SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) ReconnectMCPClient(ctx context.Context, id string) error // Logging related callbacks NewLogEntryAdded(ctx context.Context, logEntry *logstore.Log) error @@ -161,6 +168,10 @@ func (s *GovernanceInMemoryStore) GetConfiguredProviders() map[schemas.ModelProv return s.Config.Providers } +func (s *GovernanceInMemoryStore) GetMCPClientsAllowingAllVirtualKeys() map[string]string { + return s.Config.GetAllowOnAllVirtualKeysClients() +} + // AddMCPClient adds a new MCP client to the in-memory store func (s *BifrostHTTPServer) AddMCPClient(ctx context.Context, clientConfig *schemas.MCPClientConfig) error { if err := s.Config.AddMCPClient(ctx, clientConfig); err != nil { @@ -230,6 +241,21 @@ func (s *BifrostHTTPServer) RemoveMCPClient(ctx context.Context, id string) erro return nil } +// VerifyPerUserOAuthConnection delegates to the Bifrost client to verify an MCP +// server using a temporary access token and discover available tools. +func (s *BifrostHTTPServer) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + return s.Client.VerifyPerUserOAuthConnection(ctx, config, accessToken) +} + +// SetClientTools delegates to the Bifrost client to update tool map for an existing MCP client, +// then re-syncs the MCP server so the new tools are immediately visible via /mcp. +func (s *BifrostHTTPServer) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + s.Client.SetClientTools(clientID, tools, toolNameMapping) + if err := s.MCPServerHandler.SyncAllMCPServers(context.Background()); err != nil { + logger.Warn("failed to sync MCP servers after setting client tools: %v", err) + } +} + // ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. func (s *BifrostHTTPServer) ExecuteChatMCPTool(ctx context.Context, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { bifrostCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) @@ -243,7 +269,8 @@ func (s *BifrostHTTPServer) ExecuteResponsesMCPTool(ctx context.Context, toolCal } func (s *BifrostHTTPServer) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { - return s.Client.GetAvailableMCPTools(ctx) + bifrostCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return s.Client.GetAvailableMCPTools(bifrostCtx) } // markPluginDisabled marks a plugin as disabled in the plugin status @@ -498,9 +525,15 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas } } - // Syncing models (this part always runs regardless of governance) - if err := s.Config.ModelCatalog.SetProviderPricingOverrides(provider, providerInfo.PricingOverrides); err != nil { - logger.Warn("failed to refresh pricing overrides for provider %s: %v", provider, err) + // Read current key count from in-memory store (providerInfo.Keys is not preloaded from DB) + inMemoryKeys, _ := s.Config.GetProviderKeysRaw(provider) + isKeylessProvider := providerInfo.CustomProviderConfig != nil && providerInfo.CustomProviderConfig.IsKeyLess + hasNoKeys := len(inMemoryKeys) == 0 && !isKeylessProvider + + // Getting allowed models from all provider keys (needed before model listing) + providerKeys, err := s.Config.ConfigStore.GetKeysByProvider(ctx, string(provider)) + if err != nil { + return nil, fmt.Errorf("failed to update provider model catalog: failed to get keys by provider: %s", err) } bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) @@ -508,9 +541,30 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas bfCtx.SetValue(schemas.BifrostContextKeyValidateKeys, true) // Validate keys during provider add/update defer bfCtx.Cancel() - allModels, bifrostErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - }) + // Run filtered and unfiltered model listing concurrently + var ( + allModels *schemas.BifrostListModelsResponse + bifrostErr *schemas.BifrostError + unfilteredModels *schemas.BifrostListModelsResponse + listModelsErr *schemas.BifrostError + listWg sync.WaitGroup + ) + listWg.Add(2) + go func() { + defer listWg.Done() + allModels, bifrostErr = s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + }) + }() + go func() { + defer listWg.Done() + unfilteredModels, listModelsErr = s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + Unfiltered: true, + }) + }() + listWg.Wait() + if allModels != nil && len(allModels.KeyStatuses) > 0 && s.Config.ConfigStore != nil { s.updateKeyStatus(ctx, allModels.KeyStatuses) } @@ -519,42 +573,36 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas s.updateKeyStatus(ctx, bifrostErr.ExtraFields.KeyStatuses) } - logger.Warn("failed to update provider model catalog: failed to list all models: %s. We are falling back onto the static datasheet", bifrost.GetErrorMessage(bifrostErr)) + if hasNoKeys { + logger.Warn("model discovery skipped for provider %s: no keys configured", provider) + } else { + logger.Warn("failed to update provider model catalog: failed to list all models: %s. We are falling back onto the static datasheet", bifrost.GetErrorMessage(bifrostErr)) + } // In case of error, we return an empty list of models, and fallback onto the static datasheet allModels = &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } } - // Getting allowed models from all provider keys - providerKeys, err := s.Config.ConfigStore.GetKeysByProvider(ctx, string(provider)) - if err != nil { - return nil, fmt.Errorf("failed to update provider model catalog: failed to get keys by provider: %s", err) - } - allowedInKeys := make([]schemas.Model, 0) - deniedInKeys := make([]schemas.Model, 0) + modelsInKeys := make([]schemas.Model, 0) for _, key := range providerKeys { - for _, model := range key.Models { - if !slices.Contains(key.BlacklistedModels, model) { - allowedInKeys = append(allowedInKeys, schemas.Model{ - ID: string(provider) + "/" + model, - }) - } + if key.Models.IsUnrestricted() { + continue } - for _, model := range key.BlacklistedModels { - deniedInKeys = append(deniedInKeys, schemas.Model{ + for _, model := range key.Models { + modelsInKeys = append(modelsInKeys, schemas.Model{ ID: string(provider) + "/" + model, }) } } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, allowedInKeys, deniedInKeys) - unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - Unfiltered: true, - }) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, modelsInKeys) if listModelsErr != nil { - logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + if hasNoKeys { + logger.Warn("unfiltered model discovery skipped for provider %s: no keys configured", provider) + } else { + logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } } else { - s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModelData) + s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModels) } return updatedProvider, nil } @@ -580,7 +628,6 @@ func (s *BifrostHTTPServer) RemoveProvider(ctx context.Context, provider schemas return fmt.Errorf("pricing manager not found") } s.Config.ModelCatalog.DeleteModelDataForProvider(provider) - s.Config.ModelCatalog.DeleteProviderPricingOverrides(provider) return nil } @@ -714,12 +761,13 @@ func (s *BifrostHTTPServer) UpdateDropExcessRequests(ctx context.Context, value s.Client.UpdateDropExcessRequests(value) } -// UpdateMCPToolManagerConfig updates the MCP tool manager config -func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { +// UpdateMCPToolManagerConfig updates the MCP tool manager config. +// Always pass the current disableAutoToolInject value so it is never reset. +func (s *BifrostHTTPServer) UpdateMCPToolManagerConfig(ctx context.Context, maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error { if s.Config == nil { return fmt.Errorf("config not found") } - return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel) + return s.Client.UpdateToolManagerConfig(maxAgentDepth, toolExecutionTimeoutInSeconds, codeModeBindingLevel, disableAutoToolInject) } // reloadObservabilityPlugins reloads all observability plugins in the tracing middleware @@ -755,55 +803,68 @@ func (s *BifrostHTTPServer) ForceReloadPricing(ctx context.Context) error { return fmt.Errorf("failed to initialize new model catalog: %w", err) } s.Config.ModelCatalog = modelCatalog - for provider, providerConfig := range s.Config.Providers { - if err := s.Config.ModelCatalog.SetProviderPricingOverrides(provider, providerConfig.PricingOverrides); err != nil { - logger.Warn("failed to seed pricing overrides for provider %s: %v", provider, err) - } - } } else { if err := s.Config.ModelCatalog.ForceReloadPricing(ctx); err != nil { return fmt.Errorf("failed to force reload pricing: %w", err) } // Fetching keys for all providers and allowed models first // Based on allowed models we will set the data in the model catalog + var wg sync.WaitGroup for provider, providerConfig := range s.Config.Providers { - bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) - bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) - modelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - }) - if listModelsErr != nil { - logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) - } - allowedModels := make([]schemas.Model, 0) - deniedModels := make([]schemas.Model, 0) - for _, key := range providerConfig.Keys { - for _, model := range key.Models { - if !slices.Contains(key.BlacklistedModels, model) { + wg.Add(1) + go func(provider schemas.ModelProvider, providerConfig configstore.ProviderConfig) { + defer wg.Done() + bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) + bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + defer bfCtx.Cancel() + modelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + }) + if listModelsErr != nil { + logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } + allowedModels := make([]schemas.Model, 0) + for _, key := range providerConfig.Keys { + if key.Models.IsUnrestricted() { + continue + } + for _, model := range key.Models { allowedModels = append(allowedModels, schemas.Model{ ID: string(provider) + "/" + model, }) } } - for _, model := range key.BlacklistedModels { - deniedModels = append(deniedModels, schemas.Model{ - ID: string(provider) + "/" + model, - }) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) + unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + Unfiltered: true, + }) + if listModelsErr != nil { + logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } else { + s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModelData) } - } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, deniedModels) - unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - Unfiltered: true, - }) - if listModelsErr != nil { - logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) - } else { - s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModelData) - } - bfCtx.Cancel() + }(provider, providerConfig) } + wg.Wait() + } + return nil +} + +// UpsertPricingOverride inserts or updates a pricing override in the in-memory model catalog. +func (s *BifrostHTTPServer) UpsertPricingOverride(ctx context.Context, override *tables.TablePricingOverride) error { + if s.Config == nil || s.Config.ModelCatalog == nil { + return fmt.Errorf("pricing manager not found") + } + return s.Config.ModelCatalog.UpsertPricingOverrides(override) +} + +// DeletePricingOverride removes a pricing override from the in-memory model catalog. +func (s *BifrostHTTPServer) DeletePricingOverride(ctx context.Context, id string) error { + if s.Config == nil || s.Config.ModelCatalog == nil { + return fmt.Errorf("pricing manager not found") } + s.Config.ModelCatalog.DeletePricingOverride(id) return nil } @@ -954,9 +1015,12 @@ func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlew // Initialize WebSocket pool and handler before integrations so it can be wired through s.wsPool = bfws.NewPool(s.Config.WebSocketConfig.Pool) wsResponsesHandler := handlers.NewWSResponsesHandler(s.Client, s.Config, s.wsPool) + wsRealtimeHandler := handlers.NewWSRealtimeHandler(s.Client, s.Config, s.wsPool) + webrtcRealtimeHandler := handlers.NewWebRTCRealtimeHandler(s.Client, s.Config) + realtimeClientSecretsHandler := handlers.NewRealtimeClientSecretsHandler(s.Client, s.Config) inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config) - s.IntegrationHandler = handlers.NewIntegrationHandler(s.Client, s.Config, wsResponsesHandler) + s.IntegrationHandler = handlers.NewIntegrationHandler(s.Client, s.Config, wsResponsesHandler, wsRealtimeHandler, webrtcRealtimeHandler, realtimeClientSecretsHandler) mcpInferenceHandler := handlers.NewMCPInferenceHandler(s.Client, s.Config) mcpServerHandler, err := handlers.NewMCPServerHandler(ctx, s.Config, s) if err != nil { @@ -998,6 +1062,10 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser if semanticCachePlugin != nil { cacheHandler = handlers.NewCacheHandler(semanticCachePlugin) } + var promptsReloader handlers.PromptCacheReloader + if promptsPlugin, err := lib.FindPluginAs[*prompts.Plugin](s.Config, prompts.PluginName); err == nil && promptsPlugin != nil { + promptsReloader = promptsPlugin + } // Websocket handler needs to go below UI handler logger.Debug("initializing websocket server") if s.WebSocketHandler == nil { @@ -1020,17 +1088,24 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser healthHandler := handlers.NewHealthHandler(s.Config) providerHandler := handlers.NewProviderHandler(callbacks, s.Config, s.Client) oauthHandler := handlers.NewOAuthHandler(s.Config.OAuthProvider, s.Client, s.Config) - mcpHandler := handlers.NewMCPHandler(callbacks, s.Client, s.Config, oauthHandler) + mcpHandler := handlers.NewMCPHandler(callbacks, callbacks, s.Client, s.Config, oauthHandler) configHandler := handlers.NewConfigHandler(callbacks, s.Config) pluginsHandler := handlers.NewPluginsHandler(callbacks, s.Config.ConfigStore) sessionHandler := handlers.NewSessionHandler(s.Config.ConfigStore, s.WSTicketStore) - promptsHandler := handlers.NewPromptsHandler(s.Config.ConfigStore) + promptsHandler := handlers.NewPromptsHandler(s.Config.ConfigStore, promptsReloader) // Going ahead with API handlers healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) mcpHandler.RegisterRoutes(s.Router, middlewares...) configHandler.RegisterRoutes(s.Router, middlewares...) oauthHandler.RegisterRoutes(s.Router, middlewares...) + // OAuth metadata + per-user OAuth endpoints (no auth middleware β€” must be publicly accessible) + oauthMetadataHandler := handlers.NewOAuthMetadataHandler(s.Config) + oauthMetadataHandler.RegisterRoutes(s.Router) + perUserOAuthHandler := handlers.NewPerUserOAuthHandler(s.Config) + perUserOAuthHandler.RegisterRoutes(s.Router) + consentHandler := handlers.NewConsentHandler(s.Config) + consentHandler.RegisterRoutes(s.Router) if pluginsHandler != nil { pluginsHandler.RegisterRoutes(s.Router, middlewares...) } @@ -1256,50 +1331,51 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if s.Config.ModelCatalog != nil { // Fetching keys for all providers and allowed models first // Based on allowed models we will set the data in the model catalog + var wg sync.WaitGroup for provider, providerConfig := range s.Config.Providers { - bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) - bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) - - modelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - }) - if modelData != nil && len(modelData.KeyStatuses) > 0 && s.Config.ConfigStore != nil { - s.updateKeyStatus(ctx, modelData.KeyStatuses) - } - if listModelsErr != nil { - if len(listModelsErr.ExtraFields.KeyStatuses) > 0 && s.Config.ConfigStore != nil { - s.updateKeyStatus(ctx, listModelsErr.ExtraFields.KeyStatuses) + wg.Add(1) + go func(provider schemas.ModelProvider, providerConfig configstore.ProviderConfig) { + defer wg.Done() + bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) + bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + defer bfCtx.Cancel() + + modelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + }) + if modelData != nil && len(modelData.KeyStatuses) > 0 && s.Config.ConfigStore != nil { + s.updateKeyStatus(ctx, modelData.KeyStatuses) } - logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) - } - allowedModels := make([]schemas.Model, 0) - deniedModels := make([]schemas.Model, 0) - for _, key := range providerConfig.Keys { - for _, model := range key.Models { - if !slices.Contains(key.BlacklistedModels, model) { + if listModelsErr != nil { + if len(listModelsErr.ExtraFields.KeyStatuses) > 0 && s.Config.ConfigStore != nil { + s.updateKeyStatus(ctx, listModelsErr.ExtraFields.KeyStatuses) + } + logger.Error("failed to list models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } + allowedModels := make([]schemas.Model, 0) + for _, key := range providerConfig.Keys { + if key.Models.IsUnrestricted() { + continue + } + for _, model := range key.Models { allowedModels = append(allowedModels, schemas.Model{ ID: string(provider) + "/" + model, }) } } - for _, model := range key.BlacklistedModels { - deniedModels = append(deniedModels, schemas.Model{ - ID: string(provider) + "/" + model, - }) + s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels) + unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ + Provider: provider, + Unfiltered: true, + }) + if listModelsErr != nil { + logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } else { + s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModelData) } - } - s.Config.ModelCatalog.UpsertModelDataForProvider(provider, modelData, allowedModels, deniedModels) - unfilteredModelData, listModelsErr := s.Client.ListModelsRequest(bfCtx, &schemas.BifrostListModelsRequest{ - Provider: provider, - Unfiltered: true, - }) - if listModelsErr != nil { - logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) - } else { - s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModelData) - } - bfCtx.Cancel() + }(provider, providerConfig) } + wg.Wait() } logger.Info("models added to catalog") @@ -1336,19 +1412,24 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { if ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil && s.AuthMiddleware != nil { inferenceMiddlewares = append(inferenceMiddlewares, s.AuthMiddleware.InferenceMiddleware()) } - // Registering inference middlewares - inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) + // Once auth is done we will first add the Tracing middleware + // Always add tracing middleware when tracer is enabled - it creates traces and sets traceID in context + // The observability plugins are optional (can be empty if only logging is enabled) // Curating observability plugins observabilityPlugins := s.CollectObservabilityPlugins() // This enables the central streaming accumulator for both use cases // Initializing tracer with embedded streaming accumulator traceStore := tracing.NewTraceStore(60*time.Minute, logger) tracer := tracing.NewTracer(traceStore, s.Config.ModelCatalog, logger) + tracer.SetObservabilityPlugins(observabilityPlugins) s.Client.SetTracer(tracer) - // Always add tracing middleware when tracer is enabled - it creates traces and sets traceID in context - // The observability plugins are optional (can be empty if only logging is enabled) - s.TracingMiddleware = handlers.NewTracingMiddleware(tracer, observabilityPlugins) + s.TracingMiddleware = handlers.NewTracingMiddleware(tracer) + // TransportInterceptor must be inside TracingMiddleware so that the tracing defer + // runs AFTER transport post-hooks (capturing HTTPTransportPostHook plugin logs). + // Order: Tracing.pre β†’ TransportInterceptor.pre β†’ handler β†’ TransportInterceptor.post β†’ Tracing.defer + inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{handlers.TransportInterceptorMiddleware(s.Config)}, inferenceMiddlewares...) inferenceMiddlewares = append([]schemas.BifrostHTTPMiddleware{s.TracingMiddleware.Middleware()}, inferenceMiddlewares...) + err = s.RegisterInferenceRoutes(s.Ctx, inferenceMiddlewares...) if err != nil { if s.WSTicketStore != nil { @@ -1396,6 +1477,10 @@ func (s *BifrostHTTPServer) Start() error { select { case sig := <-sigChan: logger.Info("received signal %v, initiating graceful shutdown...", sig) + if s.IntegrationHandler != nil { + logger.Info("closing realtime transport sessions...") + s.IntegrationHandler.Close() + } // Create shutdown context with timeout shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -1452,6 +1537,9 @@ func (s *BifrostHTTPServer) Start() error { } case err := <-errChan: + if s.IntegrationHandler != nil { + s.IntegrationHandler.Close() + } if s.wsPool != nil { s.wsPool.Close() } diff --git a/transports/bifrost-http/websocket/session.go b/transports/bifrost-http/websocket/session.go index 314e2fd7bc..0f75b4b6a7 100644 --- a/transports/bifrost-http/websocket/session.go +++ b/transports/bifrost-http/websocket/session.go @@ -1,9 +1,13 @@ package websocket import ( + "strings" "sync" + "time" ws "github.com/fasthttp/websocket" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" ) // Session tracks the binding between a client WebSocket connection and its upstream state. @@ -12,6 +16,8 @@ type Session struct { mu sync.RWMutex writeMu sync.Mutex // serializes all WriteMessage calls to clientConn + id string + // Client connection clientConn *ws.Conn @@ -22,16 +28,64 @@ type Session struct { // LastResponseID tracks the most recent response ID for previous_response_id chaining. lastResponseID string + // providerSessionID tracks the upstream provider's session identifier when exposed. + providerSessionID string + + // realtimeOutputText accumulates assistant/provider turn text until the terminal event. + realtimeOutputText string + + // realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the + // completed assistant turn can persist the full turn history instead of only the + // latest finalized input event. + realtimeTurnInputs []RealtimeTurnInput + + // realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been + // attached to a persisted turn, so late transcript updates do not pollute later turns. + realtimeConsumedTurnItemIDs map[string]struct{} + + // realtimeTurnHooks tracks the active turn-scoped plugin pipeline between + // response.create and response.done. + realtimeTurnHooks *RealtimeTurnPluginState + realtimeTurnBusy bool + closed bool } +type RealtimeToolOutput struct { + Summary string + Raw string +} + +type RealtimeTurnInput struct { + ItemID string + Role string + Summary string + Raw string +} + +type RealtimeTurnPluginState struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() + RequestID string + StartedAt time.Time + PreHookValues map[any]any +} + // NewSession creates a new session for a client WebSocket connection. func NewSession(clientConn *ws.Conn) *Session { return &Session{ + id: uuid.NewString(), clientConn: clientConn, } } +// ID returns the stable Bifrost session identifier for this websocket session. +func (s *Session) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.id +} + // ClientConn returns the client's WebSocket connection. func (s *Session) ClientConn() *ws.Conn { return s.clientConn @@ -83,6 +137,212 @@ func (s *Session) LastResponseID() string { return s.lastResponseID } +// SetProviderSessionID stores the upstream provider session identifier when available. +func (s *Session) SetProviderSessionID(id string) { + s.mu.Lock() + defer s.mu.Unlock() + s.providerSessionID = id +} + +// ProviderSessionID returns the upstream provider session identifier when known. +func (s *Session) ProviderSessionID() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.providerSessionID +} + +// AppendRealtimeOutputText appends provider output content for the current realtime turn. +func (s *Session) AppendRealtimeOutputText(text string) { + if text == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeOutputText += text +} + +// ConsumeRealtimeOutputText returns the accumulated provider output and clears it. +func (s *Session) ConsumeRealtimeOutputText() string { + s.mu.Lock() + defer s.mu.Unlock() + text := s.realtimeOutputText + s.realtimeOutputText = "" + return text +} + +// AddRealtimeInput stores a finalized user turn event in arrival order. +func (s *Session) AddRealtimeInput(summary, raw string) { + if summary == "" && raw == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + Role: string(schemas.ChatMessageRoleUser), + Summary: summary, + Raw: raw, + }) +} + +// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID. +// Late updates for items already attached to a completed turn are ignored. +func (s *Session) RecordRealtimeInput(itemID, summary, raw string) { + s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw) +} + +// AddRealtimeToolOutput stores a pending tool result for the next assistant turn. +func (s *Session) AddRealtimeToolOutput(summary, raw string) { + if summary == "" && raw == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + Role: string(schemas.ChatMessageRoleTool), + Summary: summary, + Raw: raw, + }) +} + +// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID. +// Late updates for items already attached to a completed turn are ignored. +func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) { + s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw) +} + +func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) { + if summary == "" && raw == "" { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + itemID = strings.TrimSpace(itemID) + if itemID != "" { + if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed { + return + } + for idx := range s.realtimeTurnInputs { + if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role { + continue + } + if strings.TrimSpace(summary) != "" { + s.realtimeTurnInputs[idx].Summary = summary + } + if strings.TrimSpace(raw) != "" { + existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw) + incomingRaw := strings.TrimSpace(raw) + switch { + case existingRaw == "": + s.realtimeTurnInputs[idx].Raw = raw + case incomingRaw == "" || existingRaw == incomingRaw: + default: + s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw + } + } + return + } + } + + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + ItemID: itemID, + Role: role, + Summary: summary, + Raw: raw, + }) +} + +// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them. +func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput { + s.mu.Lock() + defer s.mu.Unlock() + inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...) + if len(inputs) > 0 { + if s.realtimeConsumedTurnItemIDs == nil { + s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs)) + } + for _, input := range inputs { + if strings.TrimSpace(input.ItemID) != "" { + s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{} + } + } + } + s.realtimeTurnInputs = nil + return inputs +} + +// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them. +func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput { + s.mu.RLock() + defer s.mu.RUnlock() + return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...) +} + +// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline. +func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) { + s.mu.Lock() + defer s.mu.Unlock() + if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnBusy = false + if s.closed { + if state != nil && state.Cleanup != nil { + state.Cleanup() + } + s.realtimeTurnHooks = nil + return + } + s.realtimeTurnHooks = state +} + +// TryBeginRealtimeTurnHooks reserves the single active turn slot. +func (s *Session) TryBeginRealtimeTurnHooks() bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil { + return false + } + s.realtimeTurnBusy = true + return true +} + +// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks. +func (s *Session) AbortRealtimeTurnHooks() { + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnBusy = false +} + +// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it. +func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState { + s.mu.RLock() + defer s.mu.RUnlock() + return s.realtimeTurnHooks +} + +// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it. +func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState { + s.mu.Lock() + defer s.mu.Unlock() + state := s.realtimeTurnHooks + s.realtimeTurnHooks = nil + s.realtimeTurnBusy = false + return state +} + +// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline. +func (s *Session) ClearRealtimeTurnHooks() { + s.mu.Lock() + defer s.mu.Unlock() + if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnHooks = nil + s.realtimeTurnBusy = false +} + // Close closes the session and its upstream connection if pinned. func (s *Session) Close() { s.mu.Lock() @@ -91,6 +351,16 @@ func (s *Session) Close() { return } s.closed = true + if s.realtimeTurnHooks != nil { + if s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnHooks = nil + } + s.realtimeTurnBusy = false + if s.clientConn != nil { + _ = s.clientConn.Close() + } if s.upstream != nil { s.upstream.Close() s.upstream = nil @@ -166,3 +436,15 @@ func (m *SessionManager) CloseAll() { session.Close() } } + +// Snapshot returns a copy of the currently tracked sessions. +func (m *SessionManager) Snapshot() []*Session { + m.mu.RLock() + defer m.mu.RUnlock() + + sessions := make([]*Session, 0, len(m.sessions)) + for _, session := range m.sessions { + sessions = append(sessions, session) + } + return sessions +} diff --git a/transports/bifrost-http/websocket/session_test.go b/transports/bifrost-http/websocket/session_test.go index 8c7a6ebb1f..148e6fe1d5 100644 --- a/transports/bifrost-http/websocket/session_test.go +++ b/transports/bifrost-http/websocket/session_test.go @@ -1,136 +1,156 @@ package websocket import ( - "net/http" - "net/http/httptest" - "strings" "testing" ws "github.com/fasthttp/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func dialTestWS(t *testing.T, server *httptest.Server) *ws.Conn { - t.Helper() - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - conn, _, err := ws.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - return conn -} +func TestSessionManagerCreateAndGet(t *testing.T) { + manager := NewSessionManager(2) + conn := newTestConn() -func startEchoServer(t *testing.T) *httptest.Server { - t.Helper() - upgrader := ws.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, - } - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - mt, msg, err := conn.ReadMessage() - if err != nil { - break - } - conn.WriteMessage(mt, msg) - } - })) + session, err := manager.Create(conn) + if err != nil { + t.Fatalf("Create() unexpected error: %v", err) + } + if session == nil { + t.Fatal("Create() returned nil session") + } + if got := manager.Get(conn); got != session { + t.Fatal("Get() did not return the created session") + } + if got := manager.Count(); got != 1 { + t.Fatalf("Count() = %d, want 1", got) + } } -func TestSessionManager_CreateAndGet(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionManagerConnectionLimit(t *testing.T) { + manager := NewSessionManager(1) - sm := NewSessionManager(10) - - conn := dialTestWS(t, server) - defer conn.Close() - - session, err := sm.Create(conn) - require.NoError(t, err) - require.NotNil(t, session) - - got := sm.Get(conn) - assert.Equal(t, session, got) - assert.Equal(t, 1, sm.Count()) + if _, err := manager.Create(newTestConn()); err != nil { + t.Fatalf("first Create() unexpected error: %v", err) + } + if _, err := manager.Create(newTestConn()); err != ErrConnectionLimitReached { + t.Fatalf("second Create() error = %v, want %v", err, ErrConnectionLimitReached) + } } -func TestSessionManager_ConnectionLimit(t *testing.T) { - server := startEchoServer(t) - defer server.Close() - - sm := NewSessionManager(2) +func TestSessionManagerRemove(t *testing.T) { + manager := NewSessionManager(2) + conn := newTestConn() - conn1 := dialTestWS(t, server) - defer conn1.Close() - conn2 := dialTestWS(t, server) - defer conn2.Close() - conn3 := dialTestWS(t, server) - defer conn3.Close() + session, err := manager.Create(conn) + if err != nil { + t.Fatalf("Create() unexpected error: %v", err) + } - _, err := sm.Create(conn1) - require.NoError(t, err) - _, err = sm.Create(conn2) - require.NoError(t, err) + manager.Remove(conn) - // Third should fail - _, err = sm.Create(conn3) - assert.ErrorIs(t, err, ErrConnectionLimitReached) - assert.Equal(t, 2, sm.Count()) + if got := manager.Get(conn); got != nil { + t.Fatal("Get() should return nil after Remove()") + } + if got := manager.Count(); got != 0 { + t.Fatalf("Count() = %d, want 0", got) + } + if !session.closed { + t.Fatal("expected removed session to be closed") + } } -func TestSessionManager_Remove(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionLastResponseID(t *testing.T) { + session := NewSession(newTestConn()) + session.SetLastResponseID("resp-123") - sm := NewSessionManager(10) + if got := session.LastResponseID(); got != "resp-123" { + t.Fatalf("LastResponseID() = %q, want %q", got, "resp-123") + } +} - conn := dialTestWS(t, server) - defer conn.Close() +func TestSessionManagerCloseAll(t *testing.T) { + manager := NewSessionManager(4) + connA := newTestConn() + connB := newTestConn() - _, err := sm.Create(conn) - require.NoError(t, err) - assert.Equal(t, 1, sm.Count()) + sessionA, err := manager.Create(connA) + if err != nil { + t.Fatalf("Create(connA) unexpected error: %v", err) + } + sessionB, err := manager.Create(connB) + if err != nil { + t.Fatalf("Create(connB) unexpected error: %v", err) + } + + manager.CloseAll() - sm.Remove(conn) - assert.Equal(t, 0, sm.Count()) - assert.Nil(t, sm.Get(conn)) + if got := manager.Count(); got != 0 { + t.Fatalf("Count() = %d, want 0", got) + } + if !sessionA.closed || !sessionB.closed { + t.Fatal("expected all sessions to be closed") + } } -func TestSession_LastResponseID(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionRealtimeState(t *testing.T) { + session := NewSession(newTestConn()) + if session.ID() == "" { + t.Fatal("expected session ID to be populated") + } - conn := dialTestWS(t, server) - defer conn.Close() + session.SetProviderSessionID("provider-session-1") + if got := session.ProviderSessionID(); got != "provider-session-1" { + t.Fatalf("ProviderSessionID() = %q, want %q", got, "provider-session-1") + } - session := NewSession(conn) - assert.Equal(t, "", session.LastResponseID()) + session.AppendRealtimeOutputText("hello") + session.AppendRealtimeOutputText(" world") + if got := session.ConsumeRealtimeOutputText(); got != "hello world" { + t.Fatalf("ConsumeRealtimeOutputText() = %q, want %q", got, "hello world") + } + if got := session.ConsumeRealtimeOutputText(); got != "" { + t.Fatalf("ConsumeRealtimeOutputText() after clear = %q, want empty string", got) + } - session.SetLastResponseID("resp_123") - assert.Equal(t, "resp_123", session.LastResponseID()) + session.AddRealtimeInput("hello", `{"type":"conversation.item.create","item":{"role":"user"}}`) + session.AddRealtimeToolOutput("tool result", `{"type":"conversation.item.create","item":{"type":"function_call_output"}}`) + turnInputs := session.ConsumeRealtimeTurnInputs() + if len(turnInputs) != 2 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 2", len(turnInputs)) + } + if turnInputs[0].Role != "user" || turnInputs[0].Summary != "hello" { + t.Fatalf("turnInputs[0] = %+v, want user hello", turnInputs[0]) + } + if turnInputs[1].Role != "tool" || turnInputs[1].Summary != "tool result" { + t.Fatalf("turnInputs[1] = %+v, want tool result", turnInputs[1]) + } + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) after clear = %d, want 0", len(got)) + } } -func TestSessionManager_CloseAll(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionRecordRealtimeInputUpdatesPendingItemAndIgnoresConsumedLateUpdate(t *testing.T) { + session := NewSession(newTestConn()) - sm := NewSessionManager(10) + session.RecordRealtimeInput("item_1", "[Audio transcription unavailable]", `{"type":"conversation.item.done","item":{"id":"item_1"}}`) + session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`) - conn1 := dialTestWS(t, server) - defer conn1.Close() - conn2 := dialTestWS(t, server) - defer conn2.Close() + turnInputs := session.ConsumeRealtimeTurnInputs() + if len(turnInputs) != 1 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 1", len(turnInputs)) + } + if turnInputs[0].ItemID != "item_1" { + t.Fatalf("turnInputs[0].ItemID = %q, want %q", turnInputs[0].ItemID, "item_1") + } + if turnInputs[0].Summary != "Hello." { + t.Fatalf("turnInputs[0].Summary = %q, want %q", turnInputs[0].Summary, "Hello.") + } - _, err := sm.Create(conn1) - assert.NoError(t, err) - _, err = sm.Create(conn2) - assert.NoError(t, err) - assert.Equal(t, 2, sm.Count()) + session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`) + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) after late consumed update = %d, want 0", len(got)) + } +} - sm.CloseAll() - assert.Equal(t, 0, sm.Count()) +func newTestConn() *ws.Conn { + return &ws.Conn{} } diff --git a/transports/changelog.md b/transports/changelog.md index e69de29bb2..5c565957e1 100644 --- a/transports/changelog.md +++ b/transports/changelog.md @@ -0,0 +1,30 @@ +## ✨ Features + +- **Realtime Support** β€” Add WebSocket, WebRTC, and client secret handlers with session state management and transport context helpers +- **Fireworks AI Provider** β€” Add Fireworks AI as a first-class provider with native completions, responses, embeddings, and image generations (thanks [@ivanetchart](https://github.com/ivanetchart)!) +- **Per-User OAuth Consent** β€” Add per-user OAuth consent flow with identity selection and MCP authentication +- **Prompts Plugin** β€” New prompts plugin with direct key header resolver and selective message inclusion when committing prompt sessions +- **Access Profiles** β€” Add access profiles for fine-grained permission control +- **Bedrock Embeddings & Image Gen** β€” Add embeddings, image gen, edit and variation support to Bedrock +- **EnvVar Improvements** β€” Add IsSet method to EnvVar and auto-redact env-backed values in JSON serialization +- **Logging Tracking Fields** β€” Add support for tracking userId, teamId, customerId, and businessUnitId in logging +- **Virtual Keys Export** β€” Add sorting and CSV export to virtual keys table +- **Path Whitelisting** β€” Allow path whitelisting from security config +- **Server Bootstrap Timer** β€” Add server bootstrap timer for startup diagnostics + +## 🐞 Fixed + +- **Bedrock Tool Choice** β€” Fix bedrock tool choice conversion to auto +- **Bedrock Streaming Retries** β€” Retry retryable AWS exceptions and stale/closed-connection errors in bedrock streaming +- **Bedrock SigV4 Service** β€” Correct SigV4 service name for agent runtime rerank +- **MCP Tool Logs** β€” Fix MCP tool logs not being captured correctly +- **Routing Rule Targets** β€” Preserve routing rule targets for genai and bedrock paths +- **Provider Budget Duplication** β€” Fix provider level multiline budget duplication issue +- **Vertex Endpoint** β€” Fix vertex endpoint correction +- **Gemini Thinking Budget** β€” Fix thinking budget validation for gemini models +- **SQLite Migrations** β€” Fix SQLite migration connections, error handling, and disable foreign key checks during migration +- **Tool Parameter Schemas** β€” Preserve explicit empty tool parameter schemas for openai passthrough +- **List Models Output** β€” Include raw model ID in list-models output alongside aliases +- **Config Schema** β€” Fix config schema for bedrock key config +- **Data Race Fix** β€” Fix race in data reading from fasthttp request for integrations +- **Model Listing** β€” Unify /api/models and /api/models/details listing behavior diff --git a/transports/config.schema.json b/transports/config.schema.json index ecbf788437..1e5070d76f 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -172,7 +172,10 @@ }, "mcp_code_mode_binding_level": { "type": "string", - "enum": ["server", "tool"], + "enum": [ + "server", + "tool" + ], "description": "Code mode binding level for MCP tools" }, "mcp_tool_sync_interval": { @@ -180,6 +183,17 @@ "minimum": 0, "description": "Global tool sync interval in minutes (0 = disabled)", "default": 10 + }, + "mcp_disable_auto_tool_inject": { + "type": "boolean", + "description": "When true, MCP tools are not automatically injected into requests. Tools are only included when explicitly specified via request context filters or headers, such as x-bf-mcp-include-tools or x-bf-mcp-include-clients.", + "default": false + }, + "routing_chain_max_depth": { + "type": "integer", + "minimum": 1, + "description": "Maximum depth for routing rule chain evaluation", + "default": 10 } }, "additionalProperties": false @@ -219,7 +233,7 @@ "$ref": "#/$defs/provider" }, "ollama": { - "$ref": "#/$defs/provider" + "$ref": "#/$defs/provider_with_ollama_config" }, "groq": { "$ref": "#/$defs/provider" @@ -231,7 +245,7 @@ "$ref": "#/$defs/provider" }, "sgl": { - "$ref": "#/$defs/provider" + "$ref": "#/$defs/provider_with_sgl_config" }, "parasail": { "$ref": "#/$defs/provider" @@ -240,7 +254,7 @@ "$ref": "#/$defs/provider" }, "replicate": { - "$ref": "#/$defs/provider" + "$ref": "#/$defs/provider_with_replicate_config" }, "elevenlabs": { "$ref": "#/$defs/provider" @@ -292,10 +306,13 @@ "format": "date-time", "description": "Last time budget was reset" }, - "calendar_aligned": { - "type": "boolean", - "description": "Snap resets to calendar boundaries (day/week/month/year start)", - "default": false + "virtual_key_id": { + "type": "string", + "description": "ID of the virtual key this budget belongs to (mutually exclusive with provider_config_id)" + }, + "provider_config_id": { + "type": "integer", + "description": "ID of the provider config this budget belongs to (mutually exclusive with virtual_key_id)" } }, "required": [ @@ -462,6 +479,11 @@ "description": "Whether the virtual key is active", "default": true }, + "calendar_aligned": { + "type": "boolean", + "description": "Snap all budget resets to calendar boundaries (day, week, month, year)", + "default": false + }, "team_id": { "type": "string", "description": "Associated team ID (mutually exclusive with customer_id)" @@ -470,24 +492,20 @@ "type": "string", "description": "Associated customer ID (mutually exclusive with team_id)" }, - "budget_id": { - "type": "string", - "description": "Associated budget ID" - }, "rate_limit_id": { "type": "string", "description": "Associated rate limit ID" }, "provider_configs": { "type": "array", - "description": "Provider configurations for this virtual key (empty means all providers allowed)", + "description": "Provider configurations for this virtual key (empty means no providers allowed, deny-by-default)", "items": { "$ref": "#/$defs/virtual_key_provider_config" } }, "mcp_configs": { "type": "array", - "description": "MCP configurations for this virtual key", + "description": "MCP configurations for this virtual key (empty array means no MCP tools allowed, deny-by-default)", "items": { "$ref": "#/$defs/virtual_key_mcp_config" } @@ -507,6 +525,13 @@ "$ref": "#/$defs/routing_rule" } }, + "pricing_overrides": { + "type": "array", + "description": "Scoped pricing overrides applied at runtime by the model catalog", + "items": { + "$ref": "#/$defs/pricing_override" + } + }, "auth_config": { "$ref": "#/$defs/auth_config" }, @@ -575,7 +600,9 @@ "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" } }, - "required": ["name"] + "required": [ + "name" + ] } } }, @@ -910,7 +937,7 @@ }, "name": { "type": "string", - "description": "Name of the plugin (built-in: telemetry, logging, governance, maxim, semantic_cache, otel, or custom plugin name)" + "description": "Name of the plugin (built-in: telemetry, prompts, logging, governance, maxim, semantic_cache, otel, or custom plugin name)" }, "config": { "type": "object", @@ -930,7 +957,10 @@ }, "placement": { "type": "string", - "enum": ["pre_builtin", "post_builtin"], + "enum": [ + "pre_builtin", + "post_builtin" + ], "description": "Whether this plugin runs before or after built-in plugins. Default: post_builtin", "optional": true, "default": "post_builtin" @@ -1010,10 +1040,15 @@ "description": "Password for basic authentication" } }, - "required": ["username", "password"] + "required": [ + "username", + "password" + ] } }, - "required": ["push_gateway_url"] + "required": [ + "push_gateway_url" + ] } }, "additionalProperties": false @@ -1475,7 +1510,9 @@ "maximum": 1 } }, - "required": ["weight"], + "required": [ + "weight" + ], "additionalProperties": false }, "routing_rule": { @@ -1522,12 +1559,17 @@ }, "scope": { "type": "string", - "enum": ["global", "team", "customer", "virtual_key"], + "enum": [ + "global", + "team", + "customer", + "virtual_key" + ], "description": "Rule scope level", "default": "global" }, "scope_id": { - "type": ["string", "null"], + "type": "string", "description": "Entity ID for non-global scopes (required for non-global scope)" }, "priority": { @@ -1541,8 +1583,37 @@ "additionalProperties": true } }, - "required": ["id", "name", "targets"], - "additionalProperties": false + "required": [ + "id", + "name", + "targets" + ], + "additionalProperties": false, + "if": { + "properties": { + "scope": { + "enum": [ + "team", + "customer", + "virtual_key" + ] + } + }, + "required": [ + "scope" + ] + }, + "then": { + "required": [ + "scope_id" + ], + "properties": { + "scope_id": { + "type": "string", + "minLength": 1 + } + } + } }, "virtual_key_provider_config": { "type": "object", @@ -1561,28 +1632,27 @@ "description": "Provider name" }, "weight": { - "type": "number", - "description": "Weight for load balancing", - "default": 1.0 + "type": [ + "number", + "null" + ], + "description": "Weight for load balancing (null opts out of weighted routing)", + "default": null }, "allowed_models": { "type": "array", - "description": "Allowed models for this provider config (empty means all models allowed)", + "description": "Allowed models for this provider config. Use [\"*\"] to allow all models; empty array denies all (deny-by-default).", "items": { "type": "string" } }, - "budget_id": { - "type": "string", - "description": "Associated budget ID" - }, "rate_limit_id": { "type": "string", "description": "Associated rate limit ID" }, - "keys": { + "key_ids": { "type": "array", - "description": "Provider keys for this config (empty means all keys allowed for this provider)", + "description": "Keys allowed for this provider config. Use [\"*\"] to allow all keys; empty array denies all (deny-by-default). In config.json, values are key names. Via the API, values are key UUIDs.", "items": { "type": "object", "properties": { @@ -1851,7 +1921,7 @@ }, "tools_to_execute": { "type": "array", - "description": "Tools to execute for this MCP config", + "description": "Include-only list of tools this Virtual Key is permitted to execute from this MCP client. ['*'] means all tools allowed, [] means no tools allowed (deny-by-default).", "items": { "type": "string" } @@ -1904,167 +1974,74 @@ }, "additionalProperties": false }, - "pricing_override_match_type": { - "type": "string", - "enum": [ - "exact", - "wildcard", - "regex" - ] - }, - "pricing_override_request_type": { - "type": "string", - "enum": [ - "text_completion", - "text_completion_stream", - "chat_completion", - "chat_completion_stream", - "responses", - "responses_stream", - "embedding", - "rerank", - "speech", - "speech_stream", - "transcription", - "transcription_stream", - "image_generation", - "image_generation_stream" - ] - }, - "provider_pricing_override": { + "network_config": { "type": "object", "properties": { - "model_pattern": { + "base_url": { "type": "string", - "minLength": 1 + "format": "uri", + "description": "Base URL for the provider (optional, required for Ollama)" }, - "match_type": { - "$ref": "#/$defs/pricing_override_match_type" + "extra_headers": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Additional headers to send with requests" }, - "request_types": { - "type": "array", - "items": { - "$ref": "#/$defs/pricing_override_request_type" - } + "default_request_timeout_in_seconds": { + "type": "integer", + "minimum": 1, + "description": "Default request timeout in seconds" }, - "input_cost_per_token": { "type": "number", "minimum": 0 }, - "output_cost_per_token": { "type": "number", "minimum": 0 }, - "input_cost_per_video_per_second": { "type": "number", "minimum": 0 }, - "input_cost_per_audio_per_second": { "type": "number", "minimum": 0 }, - "input_cost_per_character": { "type": "number", "minimum": 0 }, - "output_cost_per_character": { "type": "number", "minimum": 0 }, - "input_cost_per_token_above_128k_tokens": { "type": "number", "minimum": 0 }, - "input_cost_per_character_above_128k_tokens": { "type": "number", "minimum": 0 }, - "input_cost_per_image_above_128k_tokens": { "type": "number", "minimum": 0 }, - "input_cost_per_video_per_second_above_128k_tokens": { "type": "number", "minimum": 0 }, - "input_cost_per_audio_per_second_above_128k_tokens": { "type": "number", "minimum": 0 }, - "output_cost_per_token_above_128k_tokens": { "type": "number", "minimum": 0 }, - "output_cost_per_character_above_128k_tokens": { "type": "number", "minimum": 0 }, - "input_cost_per_token_above_200k_tokens": { "type": "number", "minimum": 0 }, - "output_cost_per_token_above_200k_tokens": { "type": "number", "minimum": 0 }, - "cache_creation_input_token_cost_above_200k_tokens": { "type": "number", "minimum": 0 }, - "cache_read_input_token_cost_above_200k_tokens": { "type": "number", "minimum": 0 }, - "cache_read_input_token_cost": { "type": "number", "minimum": 0 }, - "cache_creation_input_token_cost": { "type": "number", "minimum": 0 }, - "input_cost_per_token_batches": { "type": "number", "minimum": 0 }, - "output_cost_per_token_batches": { "type": "number", "minimum": 0 }, - "input_cost_per_image_token": { "type": "number", "minimum": 0 }, - "output_cost_per_image_token": { "type": "number", "minimum": 0 }, - "input_cost_per_image": { "type": "number", "minimum": 0 }, - "output_cost_per_image": { "type": "number", "minimum": 0 }, - "cache_read_input_image_token_cost": { "type": "number", "minimum": 0 } - }, - "required": [ - "model_pattern", - "match_type" - ], - "additionalProperties": false - }, - "custom_provider_config": { - "type": "object", - "description": "Custom provider configuration for extending or customizing provider behavior", - "properties": { - "is_key_less": { + "max_retries": { + "type": "integer", + "minimum": 0, + "description": "Maximum number of retries" + }, + "retry_backoff_initial_ms": { + "type": "integer", + "minimum": 0, + "description": "Initial retry backoff in milliseconds" + }, + "retry_backoff_max_ms": { + "type": "integer", + "minimum": 0, + "description": "Maximum retry backoff in milliseconds" + }, + "insecure_skip_verify": { "type": "boolean", - "description": "Whether the custom provider requires a key" + "description": "Disable TLS certificate verification for provider connections. This bypasses server certificate validation and should be used only as a last resort when a trusted CA chain cannot be configured. Prefer ca_cert_pem for self-signed or private CA deployments." }, - "base_provider_type": { + "ca_cert_pem": { "type": "string", - "description": "Base provider type to extend" + "description": "PEM-encoded CA certificate to trust for provider endpoint connections (e.g. self-signed or internal CA)" }, - "allowed_requests": { - "type": "object", - "description": "Allowed request types for the custom provider", - "properties": { - "list_models": { "type": "boolean" }, - "text_completion": { "type": "boolean" }, - "text_completion_stream": { "type": "boolean" }, - "chat_completion": { "type": "boolean" }, - "chat_completion_stream": { "type": "boolean" }, - "responses": { "type": "boolean" }, - "responses_stream": { "type": "boolean" }, - "count_tokens": { "type": "boolean" }, - "embedding": { "type": "boolean" }, - "rerank": { "type": "boolean" }, - "speech": { "type": "boolean" }, - "speech_stream": { "type": "boolean" }, - "transcription": { "type": "boolean" }, - "transcription_stream": { "type": "boolean" }, - "image_generation": { "type": "boolean" }, - "image_generation_stream": { "type": "boolean" }, - "image_edit": { "type": "boolean" }, - "image_edit_stream": { "type": "boolean" }, - "image_variation": { "type": "boolean" }, - "video_generation": { "type": "boolean" }, - "video_retrieve": { "type": "boolean" }, - "video_download": { "type": "boolean" }, - "video_delete": { "type": "boolean" }, - "video_list": { "type": "boolean" }, - "video_remix": { "type": "boolean" }, - "batch_create": { "type": "boolean" }, - "batch_list": { "type": "boolean" }, - "batch_retrieve": { "type": "boolean" }, - "batch_cancel": { "type": "boolean" }, - "batch_delete": { "type": "boolean" }, - "batch_results": { "type": "boolean" }, - "file_upload": { "type": "boolean" }, - "file_list": { "type": "boolean" }, - "file_retrieve": { "type": "boolean" }, - "file_delete": { "type": "boolean" }, - "file_content": { "type": "boolean" }, - "container_create": { "type": "boolean" }, - "container_list": { "type": "boolean" }, - "container_retrieve": { "type": "boolean" }, - "container_delete": { "type": "boolean" }, - "container_file_create": { "type": "boolean" }, - "container_file_list": { "type": "boolean" }, - "container_file_retrieve": { "type": "boolean" }, - "container_file_content": { "type": "boolean" }, - "container_file_delete": { "type": "boolean" }, - "passthrough": { "type": "boolean" }, - "passthrough_stream": { "type": "boolean" } - }, - "additionalProperties": false + "stream_idle_timeout_in_seconds": { + "type": "integer", + "minimum": 5, + "maximum": 3600, + "description": "Idle timeout per stream chunk in seconds. If no data is received for this many seconds, the stream is closed. Default: 60." }, - "request_path_overrides": { + "max_conns_per_host": { + "type": "integer", + "minimum": 1, + "maximum": 10000, + "description": "Maximum number of TCP connections per provider host. For HTTP/2 (e.g. Bedrock), each connection supports ~100 concurrent streams. Default: 5000." + }, + "beta_header_overrides": { "type": "object", - "description": "Mapping of request type to custom path overriding the default provider path", "additionalProperties": { - "type": "string" - } + "type": "boolean" + }, + "description": "Override default Anthropic beta header support per provider. Keys are header prefixes (e.g. 'redact-thinking-'), values are true (supported) or false (unsupported). Headers not listed use the built-in defaults." } }, - "required": ["base_provider_type"], "additionalProperties": false }, - "network_config": { + "network_config_without_base_url": { "type": "object", "properties": { - "base_url": { - "type": "string", - "format": "uri", - "description": "Base URL for the provider (optional, required for Ollama)" - }, "extra_headers": { "type": "object", "additionalProperties": { @@ -2158,8 +2135,7 @@ "items": { "type": "string" }, - "default": [], - "description": "Supported models for this key" + "description": "Models this key can access. Use [\"*\"] to allow all models; empty array denies all (deny-by-default)." }, "weight": { "type": "number", @@ -2170,6 +2146,17 @@ "type": "boolean", "description": "Whether this key can be used for batch API operations (default: false)", "default": false + }, + "aliases": { + "type": "object", + "additionalProperties": { + "type": "string", + "minLength": 1 + }, + "propertyNames": { + "minLength": 1 + }, + "description": "Model alias mappings: maps a model name to a provider-specific identifier (deployment name, inference profile ID, fine-tuned model ID, etc.)" } }, "required": [ @@ -2305,6 +2292,86 @@ } ] }, + "replicate_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "replicate_key_config": { + "type": "object", + "properties": { + "use_deployments_endpoint": { + "type": "boolean", + "description": "Whether to use the deployments endpoint instead of the models endpoint (default: false)" + } + }, + "additionalProperties": false + } + } + } + ] + }, + "ollama_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "ollama_key_config": { + "type": "object", + "properties": { + "url": { + "type": "string", + "minLength": 1, + "description": "Ollama server base URL (can use env. prefix)" + } + }, + "required": [ + "url" + ], + "additionalProperties": false + } + }, + "required": [ + "ollama_key_config" + ] + } + ] + }, + "sgl_key": { + "allOf": [ + { + "$ref": "#/$defs/base_key" + }, + { + "type": "object", + "properties": { + "sgl_key_config": { + "type": "object", + "properties": { + "url": { + "type": "string", + "minLength": 1, + "description": "SGLang server base URL (can use env. prefix)" + } + }, + "required": [ + "url" + ], + "additionalProperties": false + } + }, + "required": [ + "sgl_key_config" + ] + } + ] + }, "azure_key": { "allOf": [ { @@ -2320,13 +2387,6 @@ "type": "string", "description": "Azure endpoint (can use env. prefix)" }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "Model to deployment mappings" - }, "api_version": { "type": "string", "description": "Azure API version" @@ -2371,13 +2431,6 @@ "auth_credentials": { "type": "string", "description": "Authentication credentials (can use env. prefix)" - }, - "deployments": { - "type": "object", - "additionalProperties": { - "type": "string" - }, - "description": "Model to deployment mappings" } }, "required": [ @@ -2427,13 +2480,6 @@ }, "custom_provider_config": { "$ref": "#/$defs/custom_provider_config" - }, - "pricing_overrides": { - "type": "array", - "items": { - "$ref": "#/$defs/provider_pricing_override" - }, - "description": "Provider-level pricing overrides matched by model pattern" } }, "required": [ @@ -2475,13 +2521,6 @@ }, "custom_provider_config": { "$ref": "#/$defs/custom_provider_config" - }, - "pricing_overrides": { - "type": "array", - "items": { - "$ref": "#/$defs/provider_pricing_override" - }, - "description": "Provider-level pricing overrides matched by model pattern" } }, "required": [ @@ -2501,7 +2540,7 @@ "description": "API keys for this provider" }, "network_config": { - "$ref": "#/$defs/network_config" + "$ref": "#/$defs/network_config_without_base_url" }, "concurrency_and_buffer_size": { "$ref": "#/$defs/concurrency_config" @@ -2523,13 +2562,47 @@ }, "custom_provider_config": { "$ref": "#/$defs/custom_provider_config" - }, - "pricing_overrides": { + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_replicate_config": { + "type": "object", + "properties": { + "keys": { "type": "array", "items": { - "$ref": "#/$defs/provider_pricing_override" + "$ref": "#/$defs/replicate_key" }, - "description": "Provider-level pricing overrides matched by model pattern" + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config_without_base_url" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + "store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "custom_provider_config": { + "$ref": "#/$defs/custom_provider_config" } }, "required": [ @@ -2537,13 +2610,13 @@ ], "additionalProperties": false }, - "provider_with_azure_config": { + "provider_with_replicate_config": { "type": "object", "properties": { "keys": { "type": "array", "items": { - "$ref": "#/$defs/azure_key" + "$ref": "#/$defs/replicate_key" }, "minItems": 1, "description": "API keys for this provider" @@ -2571,13 +2644,47 @@ }, "custom_provider_config": { "$ref": "#/$defs/custom_provider_config" - }, - "pricing_overrides": { + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_azure_config": { + "type": "object", + "properties": { + "keys": { "type": "array", "items": { - "$ref": "#/$defs/provider_pricing_override" + "$ref": "#/$defs/azure_key" }, - "description": "Provider-level pricing overrides matched by model pattern" + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + "store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "custom_provider_config": { + "$ref": "#/$defs/custom_provider_config" } }, "required": [ @@ -2619,13 +2726,88 @@ }, "custom_provider_config": { "$ref": "#/$defs/custom_provider_config" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_ollama_config": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "$ref": "#/$defs/ollama_key" + }, + "minItems": 1, + "description": "API keys for this provider" }, - "pricing_overrides": { + "network_config": { + "$ref": "#/$defs/network_config_without_base_url" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + "store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "custom_provider_config": { + "$ref": "#/$defs/custom_provider_config" + } + }, + "required": [ + "keys" + ], + "additionalProperties": false + }, + "provider_with_sgl_config": { + "type": "object", + "properties": { + "keys": { "type": "array", "items": { - "$ref": "#/$defs/provider_pricing_override" + "$ref": "#/$defs/sgl_key" }, - "description": "Provider-level pricing overrides matched by model pattern" + "minItems": 1, + "description": "API keys for this provider" + }, + "network_config": { + "$ref": "#/$defs/network_config_without_base_url" + }, + "concurrency_and_buffer_size": { + "$ref": "#/$defs/concurrency_config" + }, + "proxy_config": { + "$ref": "#/$defs/proxy_config" + }, + "send_back_raw_request": { + "type": "boolean", + "description": "Include raw request in BifrostResponse (default: false)" + }, + "send_back_raw_response": { + "type": "boolean", + "description": "Include raw response in BifrostResponse (default: false)" + }, + "store_raw_request_response": { + "type": "boolean", + "description": "Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)" + }, + "custom_provider_config": { + "$ref": "#/$defs/custom_provider_config" } }, "required": [ @@ -2664,12 +2846,17 @@ }, "auth_type": { "type": "string", - "enum": ["none", "headers", "oauth"], + "enum": [ + "none", + "headers", + "oauth", + "per_user_oauth" + ], "description": "Authentication type for MCP connection" }, "oauth_config_id": { "type": "string", - "description": "OAuth config ID reference (for oauth auth type)" + "description": "OAuth config ID reference (required when auth_type is 'oauth' or 'per_user_oauth')" }, "headers": { "type": "object", @@ -2751,6 +2938,13 @@ "type": "string", "description": "Per-client override for tool sync interval (Go duration, e.g. '10m', '1h', 0 = use global, negative = disabled)" }, + "allowed_extra_headers": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Allowlist of request-level headers that callers may forward to this MCP server at execution time. Use ['*'] to allow all headers." + }, "is_ping_available": { "type": "boolean", "description": "Whether the MCP server supports ping for health checks (default: true)", @@ -2763,6 +2957,11 @@ "type": "number", "minimum": 0 } + }, + "allow_on_all_virtual_keys": { + "type": "boolean", + "description": "When true, this MCP server is accessible to all virtual keys without requiring explicit per-key assignment. All tools are allowed by default. If a virtual key has an explicit MCP config for this server, that config takes precedence and overrides this behaviour.", + "default": false } }, "required": [ @@ -2770,6 +2969,17 @@ "connection_type" ], "additionalProperties": false, + "if": { + "properties": { + "auth_type": { + "enum": ["oauth", "per_user_oauth"] + } + }, + "required": ["auth_type"] + }, + "then": { + "required": ["oauth_config_id"] + }, "oneOf": [ { "properties": { @@ -2798,8 +3008,16 @@ } }, "anyOf": [ - { "required": ["http_config"] }, - { "required": ["connection_string"] } + { + "required": [ + "http_config" + ] + }, + { + "required": [ + "connection_string" + ] + } ] }, { @@ -2831,8 +3049,16 @@ }, "code_mode_binding_level": { "type": "string", - "enum": ["server", "tool"], + "enum": [ + "server", + "tool" + ], "description": "How tools are exposed in VFS for code execution" + }, + "disable_auto_tool_inject": { + "type": "boolean", + "description": "When true, MCP tools are not automatically injected into requests. Tools are only included when explicitly specified via request context filters or headers, such as x-bf-mcp-include-tools or x-bf-mcp-include-clients.", + "default": false } } }, @@ -3324,7 +3550,11 @@ }, "cloud": { "type": "string", - "enum": ["commercial", "gcc-high", "dod"], + "enum": [ + "commercial", + "gcc-high", + "dod" + ], "default": "commercial", "description": "Cloud environment: 'commercial' (default), 'gcc-high' for US Government GCC High, or 'dod' for Department of Defense" }, @@ -3619,6 +3849,284 @@ } }, "additionalProperties": false + }, + "pricing_override": { + "type": "object", + "description": "Scoped pricing override applied at runtime by the model catalog", + "properties": { + "id": { + "type": "string", + "description": "Unique pricing override ID" + }, + "name": { + "type": "string", + "description": "Human-readable name for this override" + }, + "scope_kind": { + "type": "string", + "description": "Scope level for this override", + "enum": [ + "global", + "provider", + "provider_key", + "virtual_key", + "virtual_key_provider", + "virtual_key_provider_key" + ] + }, + "virtual_key_id": { + "type": "string", + "description": "Virtual key ID (required for virtual_key* scopes)" + }, + "provider_id": { + "type": "string", + "description": "Provider ID (required for provider* scopes)" + }, + "provider_key_id": { + "type": "string", + "description": "Provider key ID (required for provider_key and virtual_key_provider_key scopes)" + }, + "match_type": { + "type": "string", + "description": "How the pattern is matched against model names", + "enum": [ + "exact", + "wildcard" + ] + }, + "pattern": { + "type": "string", + "description": "Model name pattern to match (exact name or wildcard prefix ending with *)" + }, + "request_types": { + "type": "array", + "description": "Request types this override applies to. At least one value is required.", + "minItems": 1, + "items": { + "type": "string" + } + }, + "pricing_patch": { + "type": "string", + "description": "JSON-encoded pricing fields to override (e.g. '{\"input_cost_per_token\":0.000001}')" + }, + "config_hash": { + "type": "string", + "description": "Internal hash for change detection (auto-managed)" + } + }, + "required": [ + "id", + "name", + "scope_kind", + "match_type", + "pattern", + "request_types" + ], + "additionalProperties": false + }, + "pricing_override_match_type": { + "type": "string", + "enum": [ + "exact", + "wildcard" + ] + }, + "pricing_override_request_type": { + "type": "string", + "enum": [ + "chat_completion", + "text_completion", + "responses", + "embedding", + "rerank", + "speech", + "transcription", + "image_generation", + "image_variation", + "image_edit", + "video_generation", + "video_remix" + ] + }, + "custom_provider_config": { + "type": "object", + "description": "Custom provider configuration for extending or customizing provider behavior", + "properties": { + "is_key_less": { + "type": "boolean", + "description": "Whether the custom provider requires a key" + }, + "base_provider_type": { + "type": "string", + "description": "Base provider type to extend" + }, + "allowed_requests": { + "type": "object", + "description": "Allowed request types for the custom provider", + "properties": { + "list_models": { + "type": "boolean" + }, + "text_completion": { + "type": "boolean" + }, + "text_completion_stream": { + "type": "boolean" + }, + "chat_completion": { + "type": "boolean" + }, + "chat_completion_stream": { + "type": "boolean" + }, + "responses": { + "type": "boolean" + }, + "responses_stream": { + "type": "boolean" + }, + "count_tokens": { + "type": "boolean" + }, + "embedding": { + "type": "boolean" + }, + "rerank": { + "type": "boolean" + }, + "speech": { + "type": "boolean" + }, + "speech_stream": { + "type": "boolean" + }, + "transcription": { + "type": "boolean" + }, + "transcription_stream": { + "type": "boolean" + }, + "image_generation": { + "type": "boolean" + }, + "image_generation_stream": { + "type": "boolean" + }, + "image_edit": { + "type": "boolean" + }, + "image_edit_stream": { + "type": "boolean" + }, + "image_variation": { + "type": "boolean" + }, + "video_generation": { + "type": "boolean" + }, + "video_retrieve": { + "type": "boolean" + }, + "video_download": { + "type": "boolean" + }, + "video_delete": { + "type": "boolean" + }, + "video_list": { + "type": "boolean" + }, + "video_remix": { + "type": "boolean" + }, + "batch_create": { + "type": "boolean" + }, + "batch_list": { + "type": "boolean" + }, + "batch_retrieve": { + "type": "boolean" + }, + "batch_cancel": { + "type": "boolean" + }, + "batch_delete": { + "type": "boolean" + }, + "batch_results": { + "type": "boolean" + }, + "file_upload": { + "type": "boolean" + }, + "file_list": { + "type": "boolean" + }, + "file_retrieve": { + "type": "boolean" + }, + "file_delete": { + "type": "boolean" + }, + "file_content": { + "type": "boolean" + }, + "container_create": { + "type": "boolean" + }, + "container_list": { + "type": "boolean" + }, + "container_retrieve": { + "type": "boolean" + }, + "container_delete": { + "type": "boolean" + }, + "container_file_create": { + "type": "boolean" + }, + "container_file_list": { + "type": "boolean" + }, + "container_file_retrieve": { + "type": "boolean" + }, + "container_file_content": { + "type": "boolean" + }, + "container_file_delete": { + "type": "boolean" + }, + "passthrough": { + "type": "boolean" + }, + "passthrough_stream": { + "type": "boolean" + }, + "websocket_responses": { + "type": "boolean" + }, + "realtime": { + "type": "boolean" + } + }, + "additionalProperties": false + }, + "request_path_overrides": { + "type": "object", + "description": "Mapping of request type to custom path overriding the default provider path", + "additionalProperties": { + "type": "string" + } + } + }, + "required": [ + "base_provider_type" + ], + "additionalProperties": false } } -} +} \ No newline at end of file diff --git a/transports/go.mod b/transports/go.mod index 6486feccae..6750928817 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -12,15 +12,18 @@ require ( github.com/google/uuid v1.6.0 github.com/klauspost/compress v1.18.2 github.com/mark3labs/mcp-go v0.43.2 - github.com/maximhq/bifrost/core v1.4.17 - github.com/maximhq/bifrost/framework v1.2.36 - github.com/maximhq/bifrost/plugins/governance v1.4.36 - github.com/maximhq/bifrost/plugins/litellmcompat v0.0.25 - github.com/maximhq/bifrost/plugins/logging v1.4.36 - github.com/maximhq/bifrost/plugins/maxim v1.5.36 - github.com/maximhq/bifrost/plugins/otel v1.1.35 - github.com/maximhq/bifrost/plugins/semanticcache v1.4.34 - github.com/maximhq/bifrost/plugins/telemetry v1.4.36 + github.com/maximhq/bifrost/core v1.5.0 + github.com/maximhq/bifrost/framework v1.3.0 + github.com/maximhq/bifrost/plugins/governance v1.5.0 + github.com/maximhq/bifrost/plugins/litellmcompat v0.1.0 + github.com/maximhq/bifrost/plugins/logging v1.5.0 + github.com/maximhq/bifrost/plugins/maxim v1.6.0 + github.com/maximhq/bifrost/plugins/otel v1.2.0 + github.com/maximhq/bifrost/plugins/prompts v1.0.1 + github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 + github.com/maximhq/bifrost/plugins/telemetry v1.5.0 + github.com/pion/rtcp v1.2.16 + github.com/pion/webrtc/v4 v4.2.9 github.com/prometheus/client_golang v1.23.2 github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 github.com/stretchr/testify v1.11.1 @@ -111,12 +114,26 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect - github.com/maximhq/bifrost/plugins/mocker v1.4.35 // indirect + github.com/maximhq/bifrost/plugins/mocker v1.5.0 // indirect github.com/maximhq/maxim-go v0.2.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pinecone-io/go-pinecone/v5 v5.3.0 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/ice/v4 v4.2.1 // indirect + github.com/pion/interceptor v0.1.44 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtp v1.10.1 // indirect + github.com/pion/sctp v1.9.2 // indirect + github.com/pion/sdp/v3 v3.0.18 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect + github.com/pion/turn/v4 v4.1.4 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -138,6 +155,7 @@ require ( github.com/weaviate/weaviate v1.36.5 // indirect github.com/weaviate/weaviate-go-client/v5 v5.7.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/wlynxg/anet v0.0.5 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.mongodb.org/mongo-driver v1.17.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect @@ -159,6 +177,7 @@ require ( golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect google.golang.org/grpc v1.79.3 // indirect diff --git a/transports/go.sum b/transports/go.sum index d3f9df0c88..cf68e1a7ec 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -213,26 +213,28 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.4.17 h1:jI3tM3e6szXMKx3CuGH/Z5ks2GpRMS13r6QuITJb9z0= -github.com/maximhq/bifrost/core v1.4.17/go.mod h1:O6VEP2MHkQgo1iLYoxGQ7a+3VBBlHoETCH+pOR6Q5X8= -github.com/maximhq/bifrost/framework v1.2.36 h1:CD0/63I6J6iF5vqG68zlHEXAX9xXmHd66ZXoi83AFBs= -github.com/maximhq/bifrost/framework v1.2.36/go.mod h1:hq6UGS/Goc4wYk8sa5XEGlob8YfgsG6P/WTYsqf2smw= -github.com/maximhq/bifrost/plugins/governance v1.4.36 h1:NAnhfO+0+gHO7ym2Ir++KsrQJU0Rsgtu+p0Qpvk47oo= -github.com/maximhq/bifrost/plugins/governance v1.4.36/go.mod h1:bCTXEGOe5JLa3tstfvxN/2nD1tjspaIoGmNQO7sSdqI= -github.com/maximhq/bifrost/plugins/litellmcompat v0.0.25 h1:xiFFBq6bnUr83URi9mCwsQ62nT1a7A0QC94IEu2JHos= -github.com/maximhq/bifrost/plugins/litellmcompat v0.0.25/go.mod h1:v4jTiTHb7R4obYsKh5uaSHVBNExge59wCE9EqLuCG1c= -github.com/maximhq/bifrost/plugins/logging v1.4.36 h1:ADBwS9QlFpYnGdXOqItir2H1FSEpVBoWN9ikHPMEEvQ= -github.com/maximhq/bifrost/plugins/logging v1.4.36/go.mod h1:is7qqmA//L1VwKYZ6OHXQgWewJi6VD2LFYSiXabbxlA= -github.com/maximhq/bifrost/plugins/maxim v1.5.36 h1:AmqoWY0XBGhscpVTC2QtAqnK9d1EeMqXhxQQ5uGKip4= -github.com/maximhq/bifrost/plugins/maxim v1.5.36/go.mod h1:kcIQQyeE0RMU//G3beH4UWf8sW+MJXppdgCfKy2EgzY= -github.com/maximhq/bifrost/plugins/mocker v1.4.35 h1:p9gygaMnfeS/2hSUK7VNydyZeq2EOwxMtzFjw0uyKZs= -github.com/maximhq/bifrost/plugins/mocker v1.4.35/go.mod h1:ZlbM1iDMA04ijkPENyFzhLX9Gf/LOvjI5Bd9MFNl+lQ= -github.com/maximhq/bifrost/plugins/otel v1.1.35 h1:/0eTtbFPHj3SqM+gaABdUy/sbJyCYn45Bm1pLKMpyio= -github.com/maximhq/bifrost/plugins/otel v1.1.35/go.mod h1:nfi8v40d5jO3pIEqIF6bTYpkRmGYiz5MRqfDdvYxVWs= -github.com/maximhq/bifrost/plugins/semanticcache v1.4.34 h1:kJkoiLhZas4GH2CciVAszHQfm4dnhAMsW16Vp9v5XUw= -github.com/maximhq/bifrost/plugins/semanticcache v1.4.34/go.mod h1:LmhoVafiyNBdUPd12k64S6I/lpkUIoG/O3YSwGRDn1I= -github.com/maximhq/bifrost/plugins/telemetry v1.4.36 h1:8iFQgfwtL4GeybuD2g5lEVOfnDBJ4j7+eSsqpRv5gqc= -github.com/maximhq/bifrost/plugins/telemetry v1.4.36/go.mod h1:MfrOTGXzFv6qvCWs9clXAW4l1qJR6+0DDkVy8P1GiQs= +github.com/maximhq/bifrost/core v1.5.0 h1:COg/4ssyANLeYt3VbfoU2FdgEDLcpSPpqEnvl5238AA= +github.com/maximhq/bifrost/core v1.5.0/go.mod h1:A+AHUm/jf2lWFz5RNSxcJD/ozPlFJIVK9riMM1nyjt8= +github.com/maximhq/bifrost/framework v1.3.0 h1:TRUKCM39qgJw0MrvfFPhY6UEdcgTGlxZ0zrT02ScaXw= +github.com/maximhq/bifrost/framework v1.3.0/go.mod h1:mDCR8IRMaHFffTJxyaYf9/7grG8knskluachivWjRAA= +github.com/maximhq/bifrost/plugins/governance v1.5.0 h1:cT+QiIKqJNKjl6/q0W3HTuZSeql0MHx3UWTyZPMLag4= +github.com/maximhq/bifrost/plugins/governance v1.5.0/go.mod h1:hjC5TmTdk4bES89zPUwBTwWWteHNtTV8WytdkPZUWd8= +github.com/maximhq/bifrost/plugins/litellmcompat v0.1.0 h1:asVMw3YanOeKpCj2DP8byg2cNAZe4j/91jTrXc5O56s= +github.com/maximhq/bifrost/plugins/litellmcompat v0.1.0/go.mod h1:+V/OxIyKWUhiAV3HdmvwbXI8gHp/VhVrBoipYDUzdlk= +github.com/maximhq/bifrost/plugins/logging v1.5.0 h1:uGrernx8gENT84L7fXyEpgvJZgORsGZvyq5B4PkSj80= +github.com/maximhq/bifrost/plugins/logging v1.5.0/go.mod h1:uxdMIVHUG7u5Wc5HQzXY13UlExc3lDumRgC8M+kTQiw= +github.com/maximhq/bifrost/plugins/maxim v1.6.0 h1:F23T1qcMczcuauGCYO5p9qeZOAc48FPjFdaSK9TmVeY= +github.com/maximhq/bifrost/plugins/maxim v1.6.0/go.mod h1:V/ccWAfBiW6kVXGWLe9tXKoTgFSh9sYgaJRrtEwFTso= +github.com/maximhq/bifrost/plugins/mocker v1.5.0 h1:mZ2oZNOnISG6wdhAwkxwplSl0QBPPLFk+IJYEdsi8XU= +github.com/maximhq/bifrost/plugins/mocker v1.5.0/go.mod h1:ewREvrpRbIyMwnm50MkPdMbr2a4rRwxlAiiCk6YxgYA= +github.com/maximhq/bifrost/plugins/otel v1.2.0 h1:+aJnWdryDlhza7wc4KETosX9j3Mdad5uUFBuwhslNsk= +github.com/maximhq/bifrost/plugins/otel v1.2.0/go.mod h1:BwNVvRuEgdPlSlDLzANpGy2RugWQjtHkEUoBiwT5MNI= +github.com/maximhq/bifrost/plugins/prompts v1.0.1 h1:JpM+uVkYmNLWEvg/hT8HN2Wpzax6TUsM/mdIyYzkx00= +github.com/maximhq/bifrost/plugins/prompts v1.0.1/go.mod h1:379vljFVED/0L+odEmYQaaYDY/HFy4smb8tpXXCeBvA= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 h1:tibnQ8lSnKXujnjL4mt84P/5Vxj9e9wbhvh1Tjr68JA= +github.com/maximhq/bifrost/plugins/semanticcache v1.5.0/go.mod h1:+NfIRAlHpuh5ORv0MoOf5f8uY4WPx6v/8Kuk+8FEGnw= +github.com/maximhq/bifrost/plugins/telemetry v1.5.0 h1:hECZgcsqeJSmiLrWONTFFU6APzTyILQzZuVV96oql5Q= +github.com/maximhq/bifrost/plugins/telemetry v1.5.0/go.mod h1:dl/4mtQhxooqU+r42hXajhUaq04S1X3LaH+km5UJAy0= github.com/maximhq/maxim-go v0.2.1 h1:hCp8dQ4HsyyNC+y5HCUuY/HFD0sOnGkjL5MdYCHkgEQ= github.com/maximhq/maxim-go v0.2.1/go.mod h1:nwFznXy0Dn4mxXGU4X+BCnE3VP68L+FPEaW0yUgk96o= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -243,6 +245,40 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/pinecone-io/go-pinecone/v5 v5.3.0 h1:0YQlEtmXGWK/I8ztkOVM6PuBYgFJZhjSdb0ddU+bHPE= github.com/pinecone-io/go-pinecone/v5 v5.3.0/go.mod h1:6Fg85fcyvMUQFf9KW7zniN81kelSYvsjF+KPLdc1MGA= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/ice/v4 v4.2.1 h1:XPRYXaLiFq3LFDG7a7bMrmr3mFr27G/gtXN3v/TVfxY= +github.com/pion/ice/v4 v4.2.1/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c= +github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I= +github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= +github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA= +github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.9.2 h1:HxsOzEV9pWoeggv7T5kewVkstFNcGvhMPx0GvUOUQXo= +github.com/pion/sctp v1.9.2/go.mod h1:OTOlsQ5EDQ6mQ0z4MUGXt2CgQmKyafBEXhUVqLRB6G8= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.9 h1:DZIh1HAhPIL3RvwEDFsmL5hfPSLEpxsQk9/Jir2vkJE= +github.com/pion/webrtc/v4 v4.2.9/go.mod h1:9EmLZve0H76eTzf8v2FmchZ6tcBXtDgpfTEu+drW6SY= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -314,6 +350,8 @@ github.com/weaviate/weaviate-go-client/v5 v5.7.1 h1:vEMxh486QqRqWaq58UEe/TiTbGbo github.com/weaviate/weaviate-go-client/v5 v5.7.1/go.mod h1:T/JDErjN074GrnYIa0AgK1TGUGP/6A/8vqXNPlv4c6E= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -368,6 +406,8 @@ golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 h1:7ei4lp52gK1uSejlA8AZl5AJjeLUOHBQscRQZUgAcu0= diff --git a/transports/schema_test/config_schema_test.go b/transports/schema_test/config_schema_test.go index 7ad566c555..8891c2ade7 100644 --- a/transports/schema_test/config_schema_test.go +++ b/transports/schema_test/config_schema_test.go @@ -165,21 +165,13 @@ func validateConfig(t *testing.T, schema *jsonschema.Schema, configJSON string) return schema.Validate(v) } -func TestSchemaVertexKeyDeployments(t *testing.T) { - schemaPath := getSchemaPath(t) - data, err := os.ReadFile(schemaPath) - if err != nil { - t.Fatalf("failed to read schema: %v", err) - } - var schema map[string]interface{} - if err := json.Unmarshal(data, &schema); err != nil { - t.Fatalf("failed to parse schema: %v", err) - } +func TestSchemaKeyAliases(t *testing.T) { + schema := loadSchema(t) - t.Run("vertex_key $def includes deployments field", func(t *testing.T) { - _, found := navigateJSON(schema, "$defs", "vertex_key", "allOf", 1, "properties", "vertex_key_config", "properties", "deployments") + t.Run("base_key $def includes aliases field", func(t *testing.T) { + _, found := navigateJSON(schema, "$defs", "base_key", "properties", "aliases") if !found { - t.Error("$defs/vertex_key is missing 'deployments' property β€” vertex provider uses getModelDeployment() on every request") + t.Error("$defs/base_key is missing 'aliases' property β€” aliases replaced per-provider deployments maps") } }) @@ -190,30 +182,60 @@ func TestSchemaVertexKeyDeployments(t *testing.T) { } }) - t.Run("vertex config with deployments validates successfully", func(t *testing.T) { + t.Run("vertex_key_config does not include deployments field", func(t *testing.T) { + _, found := navigateJSON(schema, "$defs", "vertex_key", "allOf", 1, "properties", "vertex_key_config", "properties", "deployments") + if found { + t.Error("$defs/vertex_key still has 'deployments' in vertex_key_config β€” deployments were moved to top-level key aliases") + } + }) + + t.Run("key with aliases validates successfully", func(t *testing.T) { compiled := compileSchema(t) config := `{ "providers": { "vertex": { "keys": [{ - "key_id": "test", "name": "test", "value": "", "weight": 1, "models": ["gemini-2.0-flash"], + "aliases": {"gemini-2.0-flash": "gemini-2.0-flash-001"}, "vertex_key_config": { "project_id": "my-project", "region": "us-central1", "auth_credentials": "", - "project_number": "123456", - "deployments": {"gemini-2.0-flash": "gemini-2.0-flash-001"} + "project_number": "123456" + } + }] + } + } + }` + if err := validateConfig(t, compiled, config); err != nil { + t.Errorf("key with aliases should be valid, got: %v", err) + } + }) + + t.Run("azure key with aliases validates successfully", func(t *testing.T) { + compiled := compileSchema(t) + config := `{ + "providers": { + "azure": { + "keys": [{ + "name": "test", + "value": "my-api-key", + "weight": 1, + "models": ["gpt-4o"], + "aliases": {"gpt-4o": "gpt-4o-deployment"}, + "azure_key_config": { + "endpoint": "https://my-resource.openai.azure.com", + "api_version": "2024-02-01" } }] } } }` if err := validateConfig(t, compiled, config); err != nil { - t.Errorf("vertex config with deployments should be valid, got: %v", err) + t.Errorf("azure key with aliases should be valid, got: %v", err) } }) } @@ -272,6 +294,7 @@ func TestSchemaClientMCPFields(t *testing.T) { "mcp_tool_execution_timeout", "mcp_code_mode_binding_level", "mcp_tool_sync_interval", + "mcp_disable_auto_tool_inject", } for _, field := range fields { t.Run("client has "+field, func(t *testing.T) { @@ -290,7 +313,8 @@ func TestSchemaClientMCPFields(t *testing.T) { "mcp_agent_depth": 5, "mcp_tool_execution_timeout": 60, "mcp_code_mode_binding_level": "server", - "mcp_tool_sync_interval": 10 + "mcp_tool_sync_interval": 10, + "mcp_disable_auto_tool_inject": false } }` if err := validateConfig(t, compiled, config); err != nil { diff --git a/transports/version b/transports/version index 543bc68a9a..a03dd36205 100644 --- a/transports/version +++ b/transports/version @@ -1 +1 @@ -1.4.20 \ No newline at end of file +1.5.0-prerelease2 \ No newline at end of file diff --git a/ui/app/_fallbacks/enterprise/components/access-profiles/accessProfilesIndexView.tsx b/ui/app/_fallbacks/enterprise/components/access-profiles/accessProfilesIndexView.tsx new file mode 100644 index 0000000000..e7747a2742 --- /dev/null +++ b/ui/app/_fallbacks/enterprise/components/access-profiles/accessProfilesIndexView.tsx @@ -0,0 +1,17 @@ +import { ShieldCheck } from "lucide-react"; +import ContactUsView from "../views/contactUsView"; + +export default function AccessProfilesIndexView() { + return ( +
+ } + title="Unlock access profiles for better performance" + description="This feature is a part of the Bifrost enterprise license. Create access profiles to control access to your resources." + readmeLink="https://docs.getbifrost.ai/enterprise/access-profiles" + testIdPrefix="access-profiles" + /> +
+ ); +} diff --git a/ui/app/_fallbacks/enterprise/components/user-groups/businessUnitsView.tsx b/ui/app/_fallbacks/enterprise/components/user-groups/businessUnitsView.tsx new file mode 100644 index 0000000000..2a74a92e3a --- /dev/null +++ b/ui/app/_fallbacks/enterprise/components/user-groups/businessUnitsView.tsx @@ -0,0 +1,17 @@ +import { Building2, Users } from "lucide-react"; +import ContactUsView from "../views/contactUsView"; + +export function BusinessUnitsView() { + return ( +
+ } + title="Unlock business units & advanced governance" + description="Manage users, business units with our enterprise-grade governance. This feature is part of the Bifrost enterprise license." + readmeLink="https://docs.getbifrost.ai/enterprise/advanced-governance" + /> +
+ ); +} diff --git a/ui/app/_fallbacks/enterprise/components/user-groups/teamsView.tsx b/ui/app/_fallbacks/enterprise/components/user-groups/teamsView.tsx new file mode 100644 index 0000000000..ecf9fa7fb4 --- /dev/null +++ b/ui/app/_fallbacks/enterprise/components/user-groups/teamsView.tsx @@ -0,0 +1,17 @@ +import { Users } from "lucide-react"; +import ContactUsView from "../views/contactUsView"; + +export function TeamsView() { + return ( +
+ } + title="Unlock teams governance" + description="Manage teams, sync from your identity provider, and control access with enterprise-grade governance. This feature is part of the Bifrost enterprise license." + readmeLink="https://docs.getbifrost.ai/enterprise/advanced-governance" + /> +
+ ) +} diff --git a/ui/app/_fallbacks/enterprise/components/user-rankings/userRankingsTab.tsx b/ui/app/_fallbacks/enterprise/components/user-rankings/userRankingsTab.tsx new file mode 100644 index 0000000000..d1cca8c2e8 --- /dev/null +++ b/ui/app/_fallbacks/enterprise/components/user-rankings/userRankingsTab.tsx @@ -0,0 +1,17 @@ +import { Users } from "lucide-react"; +import ContactUsView from "../views/contactUsView"; + +export default function UserRankingsTab() { + return ( +
+ } + title="Unlock user rankings for better visibility" + description="This feature is a part of the Bifrost enterprise license. We would love to know more about your use case and how we can help you." + readmeLink="https://docs.getbifrost.ai/enterprise/user-rankings" + testIdPrefix="user-rankings" + /> +
+ ); +} diff --git a/ui/app/_fallbacks/enterprise/lib/contexts/rbacContext.tsx b/ui/app/_fallbacks/enterprise/lib/contexts/rbacContext.tsx index 36f895a281..0f00037bf7 100644 --- a/ui/app/_fallbacks/enterprise/lib/contexts/rbacContext.tsx +++ b/ui/app/_fallbacks/enterprise/lib/contexts/rbacContext.tsx @@ -27,6 +27,7 @@ export enum RbacResource { PIIRedactor = "PIIRedactor", PromptRepository = "PromptRepository", PromptDeploymentStrategy = "PromptDeploymentStrategy", + AccessProfiles = "AccessProfiles", } // RBAC Operation Names (must match backend definitions) diff --git a/ui/app/globals.css b/ui/app/globals.css index 1402e40487..131374f5d2 100644 --- a/ui/app/globals.css +++ b/ui/app/globals.css @@ -226,19 +226,91 @@ body { } div.content-container:has(.no-padding-parent) { - @apply !p-0; + @apply p-0!; } div.content-container main.content-container-inner:has(.no-padding-parent) { - @apply !p-0; + @apply p-0!; } div.content-container:has(.no-border-parent) { - @apply !border-0; + @apply border-0!; +} + +/* ReactFlow Controls β€” follow Bifrost colour schema */ + +.react-flow__controls { + background-color: var(--card); + border: 1px solid var(--border); + border-radius: var(--radius); + box-shadow: 0 1px 3px 0 rgb(0 0 0 / 0.1); +} + +.react-flow__controls-button { + background-color: var(--card); + border-bottom: 1px solid var(--border); + fill: var(--foreground); +} + +.react-flow__controls-button:hover { + background-color: var(--muted); +} + +.react-flow__controls-button svg { + fill: var(--foreground); +} + +/* Dark mode β€” needs !important to beat ReactFlow's bundled specificity */ +.dark .react-flow__controls { + background-color: var(--card) !important; + border-color: var(--border) !important; +} + +.dark .react-flow__controls-button { + background-color: var(--card) !important; + border-bottom-color: var(--border) !important; + fill: var(--foreground) !important; +} + +.dark .react-flow__controls-button:hover { + background-color: var(--muted) !important; +} + +.dark .react-flow__controls-button svg { + fill: var(--foreground) !important; +} + +/* Dynamic chain: dash period 3+5 = 8 β€” offset must move exactly one period per loop */ +@keyframes rf-routing-tree-dynamic-chain-dash { + from { + stroke-dashoffset: 0; + } + to { + stroke-dashoffset: -8; + } +} + +.rf-chain-legend-dynamic-dash { + animation: rf-routing-tree-dynamic-chain-dash 0.5s linear infinite; +} + +.react-flow__edge.rf-chain-edge-dynamic .react-flow__edge-path { + stroke-dasharray: 3 5; + animation: rf-routing-tree-dynamic-chain-dash 0.5s linear infinite; +} + +@media (prefers-reduced-motion: reduce) { + .rf-chain-legend-dynamic-dash { + animation: none; + } + + .react-flow__edge.rf-chain-edge-dynamic .react-flow__edge-path { + animation: none; + } } /* // Custom styling for streamdown */ [data-streamdown="code-block"], [data-streamdown="code-block-body"]{ - @apply !rounded-sm; + @apply rounded-sm!; } \ No newline at end of file diff --git a/ui/app/workspace/config/pricing-config/page.tsx b/ui/app/workspace/config/pricing-config/page.tsx index 7e7bd56bb1..b5d9855824 100644 --- a/ui/app/workspace/config/pricing-config/page.tsx +++ b/ui/app/workspace/config/pricing-config/page.tsx @@ -1,11 +1,11 @@ "use client" -import PricingConfigView from "../views/pricingConfigView" +import ModelSettingsView from "../views/modelSettingsView" export default function PricingConfigPage() { return (
- +
) } diff --git a/ui/app/workspace/config/views/mcpView.tsx b/ui/app/workspace/config/views/mcpView.tsx index eaf84d3bb7..3cee876c66 100644 --- a/ui/app/workspace/config/views/mcpView.tsx +++ b/ui/app/workspace/config/views/mcpView.tsx @@ -3,6 +3,7 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { Switch } from "@/components/ui/switch"; import { getErrorMessage, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; import { CoreConfig, DefaultCoreConfig } from "@/lib/types/config"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; @@ -40,13 +41,15 @@ export default function MCPView() { } }, [config, bifrostConfig]); + const hasChanges = useMemo(() => { if (!config) return false; return ( localConfig.mcp_agent_depth !== config.mcp_agent_depth || localConfig.mcp_tool_execution_timeout !== config.mcp_tool_execution_timeout || localConfig.mcp_code_mode_binding_level !== (config.mcp_code_mode_binding_level || "server") || - localConfig.mcp_tool_sync_interval !== (config.mcp_tool_sync_interval ?? 10) + localConfig.mcp_tool_sync_interval !== (config.mcp_tool_sync_interval ?? 10) || + localConfig.mcp_disable_auto_tool_inject !== (config.mcp_disable_auto_tool_inject ?? false) ); }, [config, localConfig]); @@ -81,6 +84,10 @@ export default function MCPView() { } }, []); + const handleDisableAutoToolInjectChange = useCallback((checked: boolean) => { + setLocalConfig((prev) => ({ ...prev, mcp_disable_auto_tool_inject: checked })); + }, []); + const handleSave = useCallback(async () => { try { const agentDepth = Number.parseInt(localValues.mcp_agent_depth); @@ -170,6 +177,25 @@ export default function MCPView() { /> + {/* Disable Auto Tool Injection */} +
+
+ +

+ When enabled, MCP tools are not automatically included in every request. Tools are only injected when explicitly specified via request headers (x-bf-mcp-include-tools) and still must be allowed by the virtual key MCP configuration. +

+
+ +
+ {/* Code Mode Binding Level */}
diff --git a/ui/app/workspace/config/views/modelSettingsView.tsx b/ui/app/workspace/config/views/modelSettingsView.tsx new file mode 100644 index 0000000000..fca2788439 --- /dev/null +++ b/ui/app/workspace/config/views/modelSettingsView.tsx @@ -0,0 +1,196 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { DefaultCoreConfig } from "@/lib/types/config"; +import { getErrorMessage, useForcePricingSyncMutation, useGetCoreConfigQuery, useUpdateCoreConfigMutation } from "@/lib/store"; +import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; +import { useEffect, useMemo } from "react"; +import { useForm } from "react-hook-form"; +import { toast } from "sonner"; + +interface ModelSettingsFormData { + pricing_datasheet_url: string; + pricing_sync_interval_hours: number; + routing_chain_max_depth: number; +} + +export default function ModelSettingsView() { + const hasSettingsUpdateAccess = useRbac(RbacResource.Settings, RbacOperation.Update); + const { data: bifrostConfig } = useGetCoreConfigQuery({ fromDB: true }); + const frameworkConfig = bifrostConfig?.framework_config; + const clientConfig = bifrostConfig?.client_config; + const [updateCoreConfig, { isLoading }] = useUpdateCoreConfigMutation(); + const [forcePricingSync, { isLoading: isForceSyncing }] = useForcePricingSyncMutation(); + + const { + register, + handleSubmit, + formState: { errors, isDirty }, + reset, + watch, + } = useForm({ + defaultValues: { + pricing_datasheet_url: "", + pricing_sync_interval_hours: 24, + routing_chain_max_depth: DefaultCoreConfig.routing_chain_max_depth, + }, + }); + + const formValues = watch(); + + useEffect(() => { + if (!bifrostConfig || isDirty) return; + reset({ + pricing_datasheet_url: frameworkConfig?.pricing_url || "", + pricing_sync_interval_hours: Math.round((frameworkConfig?.pricing_sync_interval ?? 0) / 3600) || 24, + routing_chain_max_depth: clientConfig?.routing_chain_max_depth ?? DefaultCoreConfig.routing_chain_max_depth, + }); + }, [ + frameworkConfig?.pricing_url, + frameworkConfig?.pricing_sync_interval, + clientConfig?.routing_chain_max_depth, + isDirty, + reset, + ]); + + const hasChanges = useMemo(() => { + if (!bifrostConfig || !isDirty) return false; + const serverUrl = frameworkConfig?.pricing_url || ""; + const serverInterval = Math.round((frameworkConfig?.pricing_sync_interval ?? 0) / 3600); + const serverDepth = clientConfig?.routing_chain_max_depth ?? DefaultCoreConfig.routing_chain_max_depth; + return ( + formValues.pricing_datasheet_url !== serverUrl || + formValues.pricing_sync_interval_hours !== serverInterval || + formValues.routing_chain_max_depth !== serverDepth + ); + }, [bifrostConfig, frameworkConfig, clientConfig, formValues, isDirty]); + + const onSubmit = async (data: ModelSettingsFormData) => { + try { + await updateCoreConfig({ + ...bifrostConfig!, + framework_config: { + ...frameworkConfig, + id: bifrostConfig?.framework_config.id || 0, + pricing_url: data.pricing_datasheet_url, + pricing_sync_interval: data.pricing_sync_interval_hours * 3600, + }, + client_config: { + ...clientConfig!, + routing_chain_max_depth: data.routing_chain_max_depth, + }, + }).unwrap(); + toast.success("Model settings updated successfully."); + reset(data); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }; + + const handleForceSync = async () => { + try { + await forcePricingSync().unwrap(); + toast.success("Pricing sync triggered successfully."); + } catch (error) { + toast.error(getErrorMessage(error)); + } + }; + + return ( +
+
+
+

Model Settings

+

Configure pricing and routing behaviour.

+
+ +
+ {/* Pricing Datasheet URL */} +
+
+ +

URL to a custom pricing datasheet. Leave empty to use default pricing.

+
+ { + if (!value) return true; + return value.startsWith("http://") || value.startsWith("https://") || "URL must start with http:// or https://"; + }, + }, + })} + className={errors.pricing_datasheet_url ? "border-destructive" : ""} + /> + {errors.pricing_datasheet_url &&

{errors.pricing_datasheet_url.message}

} +
+ + {/* Pricing Sync Interval */} +
+
+ +

How often to sync pricing data from the datasheet URL.

+
+ + {errors.pricing_sync_interval_hours && ( +

{errors.pricing_sync_interval_hours.message}

+ )} +
+ + {/* Routing Chain Max Depth */} +
+
+ +

+ Maximum number of chained routing rule evaluations per request. Prevents infinite loops from circular rule definitions. +

+
+ +
+ {errors.routing_chain_max_depth &&

{errors.routing_chain_max_depth.message}

} +
+ +
+ + +
+
+
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/page.tsx b/ui/app/workspace/custom-pricing/overrides/page.tsx new file mode 100644 index 0000000000..69de04cfb7 --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/page.tsx @@ -0,0 +1,11 @@ +"use client"; + +import ScopedPricingOverridesView from "@/app/workspace/custom-pricing/overrides/scopedPricingOverridesView"; + +export default function ScopedPricingOverridesPage() { + return ( +
+ +
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx b/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx new file mode 100644 index 0000000000..8552e08565 --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx @@ -0,0 +1,234 @@ +"use client"; + +import { Badge } from "@/components/ui/badge"; +import { Input } from "@/components/ui/input"; +import { cn } from "@/lib/utils"; +import { ChevronDown, Plus, X } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import type { FieldErrors, PricingFieldKey } from "./pricingOverrideSheet"; +import { PRICING_FIELDS } from "./pricingOverrideSheet"; + +type GroupKey = "chat" | "embedding" | "rerank" | "audio" | "image" | "video"; + +const PRICING_GROUPS: { key: GroupKey; label: string }[] = [ + { key: "chat", label: "Chat / Text / Responses" }, + { key: "embedding", label: "Embedding" }, + { key: "rerank", label: "Rerank" }, + { key: "audio", label: "Audio" }, + { key: "image", label: "Image" }, + { key: "video", label: "Video" }, +]; + +const REQUEST_TYPE_TO_CATEGORY: Record = { + chat_completion: "chat", + text_completion: "chat", + responses: "chat", + embedding: "embedding", + rerank: "rerank", + speech: "audio", + transcription: "audio", + image_generation: "image", + image_variation: "image", + image_edit: "image", + video_generation: "video", + video_remix: "video", +}; + +interface PricingFieldSelectorProps { + values: Partial>; + errors: FieldErrors; + selectedRequestTypes?: string[]; + onChange: (key: PricingFieldKey, value: string) => void; + onFieldInteraction?: () => void; +} + +export function PricingFieldSelector({ values, errors, selectedRequestTypes, onChange, onFieldInteraction }: PricingFieldSelectorProps) { + const [search, setSearch] = useState(""); + const [openGroups, setOpenGroups] = useState>(new Set(["chat"])); + + const [activeFields, setActiveFields] = useState>( + () => new Set(PRICING_FIELDS.filter((f) => values[f.key] != null && values[f.key]!.trim() !== "").map((f) => f.key)), + ); + + // Sync active fields to exactly the set of keys that have non-empty values. + // This handles both loading new overrides (adds keys) and clearing the patch (removes stale keys). + useEffect(() => { + setActiveFields(new Set(PRICING_FIELDS.filter((f) => values[f.key] != null && values[f.key]!.trim() !== "").map((f) => f.key))); + }, [values]); + + // Derive active categories from selected request types + const activeCategories = useMemo | null>(() => { + if (!selectedRequestTypes || selectedRequestTypes.length === 0) return null; + const cats = new Set(); + for (const rt of selectedRequestTypes) { + const cat = REQUEST_TYPE_TO_CATEGORY[rt]; + if (cat) cats.add(cat); + } + return cats.size > 0 ? cats : null; + }, [selectedRequestTypes]); + + const trimmedSearch = search.trim().toLowerCase(); + const isSearching = trimmedSearch.length > 0; + + const filteredFields = useMemo(() => { + if (!isSearching) return null; + return PRICING_FIELDS.filter((f) => f.label.toLowerCase().includes(trimmedSearch) || f.key.toLowerCase().includes(trimmedSearch)); + }, [isSearching, trimmedSearch]); + + // Fields visible per group when not searching, respecting activeCategories filter + const visibleGroupedFields = useMemo( + () => + PRICING_GROUPS.map((group) => { + const fields = PRICING_FIELDS.filter((f) => { + if (f.group !== group.key) return false; + if (activeCategories === null) return true; + return (f.requestTypeGroups as readonly string[]).some((rg) => activeCategories.has(rg as GroupKey)); + }); + return { ...group, fields }; + }).filter((g) => g.fields.length > 0), + [activeCategories], + ); + + const toggleGroup = (key: GroupKey) => { + setOpenGroups((prev) => { + const next = new Set(prev); + if (next.has(key)) next.delete(key); + else next.add(key); + return next; + }); + }; + + const activateField = (key: PricingFieldKey) => { + setActiveFields((prev) => new Set([...prev, key])); + }; + + const deactivateField = (key: PricingFieldKey) => { + setActiveFields((prev) => { + const next = new Set(prev); + next.delete(key); + return next; + }); + onFieldInteraction?.(); + onChange(key, ""); + }; + + const handleInputChange = (key: PricingFieldKey, value: string) => { + onFieldInteraction?.(); + onChange(key, value); + }; + + const renderFieldRow = (field: { key: PricingFieldKey; label: string }) => { + const isActive = activeFields.has(field.key); + const hasValue = values[field.key]?.trim(); + const error = errors[field.key]; + + if (!isActive) { + return ( + + ); + } + + return ( +
+
+ {field.label} + +
+ handleInputChange(field.key, e.target.value)} + placeholder="0.0" + /> + {error &&

{error}

} +
+ ); + }; + + return ( +
+ setSearch(e.target.value)} + className="h-9" + data-testid="pricing-field-search" + /> + +
+ {isSearching ? ( +
+ {filteredFields!.length === 0 ? ( +
No fields match “{search}”
+ ) : ( + filteredFields!.map((field) => renderFieldRow(field)) + )} +
+ ) : ( +
+ {visibleGroupedFields.length === 0 ? ( +
No pricing fields for the selected request types
+ ) : ( + visibleGroupedFields.map((group) => { + const isOpen = openGroups.has(group.key); + const valueCount = group.fields.filter((f) => values[f.key]?.trim()).length; + + return ( +
+ + + {isOpen && ( +
+ {group.fields.map((field) => renderFieldRow(field))} +
+ )} +
+ ); + }) + )} +
+ )} +
+
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/pricingOverrideSheet.tsx b/ui/app/workspace/custom-pricing/overrides/pricingOverrideSheet.tsx new file mode 100644 index 0000000000..303232d36f --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/pricingOverrideSheet.tsx @@ -0,0 +1,1000 @@ +"use client"; + +import { CodeEditor } from "@/components/ui/codeEditor"; +import { Button } from "@/components/ui/button"; +import { Checkbox } from "@/components/ui/checkbox"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { PricingFieldSelector } from "./pricingFieldSelector"; +import { + getErrorMessage, + useCreatePricingOverrideMutation, + useGetProvidersQuery, + useGetVirtualKeysQuery, + useUpdatePricingOverrideMutation, +} from "@/lib/store"; +import { useGetAllKeysQuery } from "@/lib/store/apis/providersApi"; +import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; +import { getProviderLabel, RequestTypeLabels } from "@/lib/constants/logs"; +import { ModelProvider, RequestType } from "@/lib/types/config"; +import { + CreatePricingOverrideRequest, + PricingOverride, + PricingOverrideMatchType, + PricingOverridePatch, + PricingOverrideScopeKind, +} from "@/lib/types/governance"; +import { cn } from "@/lib/utils"; +import { ChevronDown, Save, X } from "lucide-react"; +import { Dispatch, SetStateAction, useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { toast } from "sonner"; + +export const REQUEST_TYPE_GROUPS = [ + { + label: "Chat / Text / Responses", + types: ["chat_completion", "text_completion", "responses"], + }, + { + label: "Embedding", + types: ["embedding"], + }, + { + label: "Rerank", + types: ["rerank"], + }, + { + label: "Audio", + types: ["speech", "transcription"], + }, + { + label: "Image", + types: ["image_generation", "image_variation", "image_edit"], + }, + { + label: "Video", + types: ["video_generation", "video_remix"], + }, +] as const; + +export const REQUEST_TYPE_OPTIONS = REQUEST_TYPE_GROUPS.flatMap((g) => g.types); + +export function getRequestTypeGroup(rt: string): string | undefined { + return REQUEST_TYPE_GROUPS.find((g) => (g.types as readonly string[]).includes(rt))?.label; +} + +export const PRICING_FIELDS = [ + // Chat / Text / Responses fields + { + key: "input_cost_per_token", + label: "Input / token", + group: "chat", + requestTypeGroups: ["chat", "embedding", "rerank", "audio", "image", "video"], + }, + { + key: "output_cost_per_token", + label: "Output / token", + group: "chat", + requestTypeGroups: ["chat", "rerank", "audio", "image", "video"], + }, + { key: "input_cost_per_token_batches", label: "Input / token (batch)", group: "chat", requestTypeGroups: ["chat"] }, + { key: "output_cost_per_token_batches", label: "Output / token (batch)", group: "chat", requestTypeGroups: ["chat"] }, + { key: "input_cost_per_token_priority", label: "Input / token (priority)", group: "chat", requestTypeGroups: ["chat"] }, + { key: "output_cost_per_token_priority", label: "Output / token (priority)", group: "chat", requestTypeGroups: ["chat"] }, + { + key: "input_cost_per_token_above_128k_tokens", + label: "Input / token (>128k)", + group: "chat", + requestTypeGroups: ["chat", "embedding", "rerank"], + }, + { + key: "output_cost_per_token_above_128k_tokens", + label: "Output / token (>128k)", + group: "chat", + requestTypeGroups: ["chat", "rerank", "audio"], + }, + { + key: "input_cost_per_token_above_200k_tokens", + label: "Input / token (>200k)", + group: "chat", + requestTypeGroups: ["chat", "embedding", "rerank"], + }, + { + key: "output_cost_per_token_above_200k_tokens", + label: "Output / token (>200k)", + group: "chat", + requestTypeGroups: ["chat", "rerank", "audio"], + }, + { key: "cache_creation_input_token_cost", label: "Cache creation / token", group: "chat", requestTypeGroups: ["chat"] }, + { key: "cache_read_input_token_cost", label: "Cache read / token", group: "chat", requestTypeGroups: ["chat"] }, + { + key: "cache_creation_input_token_cost_above_200k_tokens", + label: "Cache creation / token (>200k)", + group: "chat", + requestTypeGroups: ["chat"], + }, + { key: "cache_read_input_token_cost_above_200k_tokens", label: "Cache read / token (>200k)", group: "chat", requestTypeGroups: ["chat"] }, + { key: "cache_creation_input_token_cost_above_1hr", label: "Cache creation / token (>1hr)", group: "chat", requestTypeGroups: ["chat"] }, + { + key: "cache_creation_input_token_cost_above_1hr_above_200k_tokens", + label: "Cache creation / token (>1hr, >200k)", + group: "chat", + requestTypeGroups: ["chat"], + }, + { key: "cache_read_input_token_cost_priority", label: "Cache read / token (priority)", group: "chat", requestTypeGroups: ["chat"] }, + { key: "search_context_cost_per_query", label: "Search context / query", group: "chat", requestTypeGroups: ["chat", "rerank"] }, + { key: "code_interpreter_cost_per_session", label: "Code interpreter / session", group: "chat", requestTypeGroups: ["chat"] }, + // Audio fields + { key: "input_cost_per_character", label: "Input / character", group: "audio", requestTypeGroups: ["audio"] }, + { key: "input_cost_per_audio_token", label: "Input / audio token", group: "audio", requestTypeGroups: ["audio"] }, + { key: "input_cost_per_audio_per_second", label: "Input / audio second", group: "audio", requestTypeGroups: ["audio"] }, + { + key: "input_cost_per_audio_per_second_above_128k_tokens", + label: "Input / audio second (>128k)", + group: "audio", + requestTypeGroups: ["audio"], + }, + { key: "input_cost_per_second", label: "Input / second", group: "audio", requestTypeGroups: ["audio", "video"] }, + { key: "output_cost_per_audio_token", label: "Output / audio token", group: "audio", requestTypeGroups: ["audio"] }, + { key: "output_cost_per_second", label: "Output / second", group: "audio", requestTypeGroups: ["audio", "video"] }, + { key: "cache_creation_input_audio_token_cost", label: "Cache creation / audio token", group: "audio", requestTypeGroups: ["audio"] }, + // Image fields + { key: "input_cost_per_image_token", label: "Input / image token", group: "image", requestTypeGroups: ["image"] }, + { key: "input_cost_per_image", label: "Input / image", group: "image", requestTypeGroups: ["image"] }, + { key: "input_cost_per_image_above_128k_tokens", label: "Input / image (>128k)", group: "image", requestTypeGroups: ["image"] }, + { key: "input_cost_per_pixel", label: "Input / pixel", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_token", label: "Output / image token", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image", label: "Output / image", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_pixel", label: "Output / pixel", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_premium_image", label: "Output / image (premium)", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_above_512_and_512_pixels", label: "Output / image (>512px)", group: "image", requestTypeGroups: ["image"] }, + { + key: "output_cost_per_image_above_512_and_512_pixels_and_premium_image", + label: "Output / image (>512px, premium)", + group: "image", + requestTypeGroups: ["image"], + }, + { + key: "output_cost_per_image_above_1024_and_1024_pixels", + label: "Output / image (>1024px)", + group: "image", + requestTypeGroups: ["image"], + }, + { + key: "output_cost_per_image_above_1024_and_1024_pixels_and_premium_image", + label: "Output / image (>1024px, premium)", + group: "image", + requestTypeGroups: ["image"], + }, + { key: "output_cost_per_image_low_quality", label: "Output / image (low quality)", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_medium_quality", label: "Output / image (medium quality)", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_high_quality", label: "Output / image (high quality)", group: "image", requestTypeGroups: ["image"] }, + { key: "output_cost_per_image_auto_quality", label: "Output / image (auto quality)", group: "image", requestTypeGroups: ["image"] }, + { key: "cache_read_input_image_token_cost", label: "Cache read / image token", group: "image", requestTypeGroups: ["image"] }, + // Video fields + { key: "input_cost_per_video_per_second", label: "Input / video second", group: "video", requestTypeGroups: ["video"] }, + { + key: "input_cost_per_video_per_second_above_128k_tokens", + label: "Input / video second (>128k)", + group: "video", + requestTypeGroups: ["video"], + }, + { key: "output_cost_per_video_per_second", label: "Output / video second", group: "video", requestTypeGroups: ["video"] }, +] as const; + +export type PricingFieldKey = (typeof PRICING_FIELDS)[number]["key"]; +export type FieldErrors = Partial>; + +type ScopeRoot = "global" | "virtual_key"; + +export interface FormState { + name: string; + scopeRoot: ScopeRoot; + virtualKeyID: string; + providerID: string; + providerKeyID: string; + matchType: PricingOverrideMatchType; + pattern: string; + requestTypes: RequestType[]; + pricingValues: Partial>; +} + +export const defaultFormState: FormState = { + name: "", + scopeRoot: "global", + virtualKeyID: "", + providerID: "", + providerKeyID: "", + matchType: "exact", + pattern: "", + requestTypes: [], + pricingValues: {}, +}; + +export const fieldLabelByKey = Object.fromEntries(PRICING_FIELDS.map((field) => [field.key, field.label])) as Record< + PricingFieldKey, + string +>; +export const patchKeys = PRICING_FIELDS.map((field) => field.key) as PricingFieldKey[]; + +export function patternError(matchType: PricingOverrideMatchType, pattern: string): string | undefined { + const trimmed = pattern.trim(); + if (!trimmed) return "Pattern is required"; + if (matchType === "exact") { + if (trimmed.includes("*")) return "Exact pattern cannot contain *"; + } else if (matchType === "wildcard") { + const starCount = (trimmed.match(/\*/g) || []).length; + if (starCount === 0) return "Wildcard pattern must end with * (example: gpt-5*)"; + if (starCount > 1) return "Wildcard pattern can include only one *"; + if (!trimmed.endsWith("*")) return "Wildcard supports prefix-only trailing *"; + } + return undefined; +} + +export function buildPatchFromForm(form: FormState): { patch: PricingOverridePatch; errors: FieldErrors } { + const errors: FieldErrors = {}; + const patch: PricingOverridePatch = {}; + + for (const key of patchKeys) { + const raw = form.pricingValues[key]; + if (raw == null || raw.trim() === "") continue; + const parsed = Number(raw); + if (!Number.isFinite(parsed)) { + errors[key] = "Must be a number"; + continue; + } + if (parsed < 0) { + errors[key] = "Must be >= 0"; + continue; + } + (patch as Record)[key] = parsed; + } + + return { patch, errors }; +} + +function toFormState(override: PricingOverride): FormState { + const values: Partial> = {}; + let parsedPatch: Record = {}; + try { + if (override.pricing_patch) parsedPatch = JSON.parse(override.pricing_patch); + } catch { + // malformed patch β€” leave values empty + } + for (const key of patchKeys) { + const val = parsedPatch[key]; + if (typeof val === "number") values[key] = String(val); + } + const scopeKind = resolveScopeKind(override); + + const scopeRoot: ScopeRoot = + scopeKind === "virtual_key" || scopeKind === "virtual_key_provider" || scopeKind === "virtual_key_provider_key" + ? "virtual_key" + : "global"; + + return { + name: override.name ?? "", + scopeRoot, + virtualKeyID: override.virtual_key_id ?? "", + providerID: override.provider_id ?? "", + providerKeyID: override.provider_key_id ?? "", + matchType: override.match_type, + pattern: override.pattern, + requestTypes: override.request_types ?? [], + pricingValues: values, + }; +} + +function resolveScopeKind(override: PricingOverride): PricingOverrideScopeKind { + if ( + override.scope_kind === "global" || + override.scope_kind === "provider" || + override.scope_kind === "provider_key" || + override.scope_kind === "virtual_key" || + override.scope_kind === "virtual_key_provider" || + override.scope_kind === "virtual_key_provider_key" + ) { + return override.scope_kind; + } + if (override.virtual_key_id) { + if (override.provider_key_id) return "virtual_key_provider_key"; + if (override.provider_id) return "virtual_key_provider"; + return "virtual_key"; + } + if (override.provider_key_id) return "provider_key"; + if (override.provider_id) return "provider"; + return "global"; +} + +function deriveScopeKind(form: FormState): PricingOverrideScopeKind { + if (form.scopeRoot === "virtual_key") { + if (form.providerKeyID) return "virtual_key_provider_key"; + if (form.providerID) return "virtual_key_provider"; + return "virtual_key"; + } + if (form.providerKeyID) return "provider_key"; + if (form.providerID) return "provider"; + return "global"; +} + +export function patchSummary(override: PricingOverride): string { + let parsed: Record = {}; + try { + if (override.pricing_patch) parsed = JSON.parse(override.pricing_patch); + } catch { + // ignore + } + const keys = Object.keys(parsed) as PricingFieldKey[]; + if (keys.length === 0) return "None"; + const labels = keys.map((key) => fieldLabelByKey[key] || key); + if (labels.length <= 2) return labels.join(", "); + return `${labels.slice(0, 2).join(", ")} +${labels.length - 2} more`; +} + +export function renderFields( + fields: ReadonlyArray<{ key: PricingFieldKey; label: string }>, + form: FormState, + setForm: Dispatch>, + errors: FieldErrors, + onFieldChange?: () => void, +) { + return ( +
+ {fields.map((field) => ( +
+ + { + onFieldChange?.(); + setForm((prev) => ({ + ...prev, + pricingValues: { ...prev.pricingValues, [field.key]: e.target.value }, + })); + }} + /> + {errors[field.key] &&

{errors[field.key]}

} +
+ ))} +
+ ); +} + +interface PricingOverrideDrawerProps { + open: boolean; + onOpenChange: (open: boolean) => void; + editingOverride?: PricingOverride | null; + scopeLock?: { + scopeKind: PricingOverrideScopeKind; + virtualKeyID?: string; + providerID?: string; + providerKeyID?: string; + label?: string; + }; + onSaved?: () => void; +} + +function isCompleteScopeLock(scopeLock?: PricingOverrideDrawerProps["scopeLock"]): boolean { + if (!scopeLock) return false; + switch (scopeLock.scopeKind) { + case "global": + return true; + case "provider": + return Boolean(scopeLock.providerID); + case "provider_key": + return Boolean(scopeLock.providerKeyID); + case "virtual_key": + return Boolean(scopeLock.virtualKeyID); + case "virtual_key_provider": + return Boolean(scopeLock.virtualKeyID && scopeLock.providerID); + case "virtual_key_provider_key": + return Boolean(scopeLock.virtualKeyID && scopeLock.providerID && scopeLock.providerKeyID); + default: + return false; + } +} + +export default function PricingOverrideSheet({ open, onOpenChange, editingOverride, scopeLock, onSaved }: PricingOverrideDrawerProps) { + const { data: providersData, isLoading: isProvidersLoading, error: providersError } = useGetProvidersQuery(); + const { data: virtualKeysData, isLoading: isVirtualKeysLoading, error: virtualKeysError } = useGetVirtualKeysQuery(); + const { data: allKeysData = [] } = useGetAllKeysQuery(); + const [createOverride, { isLoading: isCreating }] = useCreatePricingOverrideMutation(); + const [updateOverride, { isLoading: isPatching }] = useUpdatePricingOverrideMutation(); + + const [form, setForm] = useState(defaultFormState); + const [jsonPatch, setJSONPatch] = useState(""); + const [jsonError, setJSONError] = useState(); + const jsonEditingRef = useRef(false); + const prevOpenRef = useRef(false); + const [requestTypePopoverOpen, setRequestTypePopoverOpen] = useState(false); + const shouldLockScope = useMemo(() => !editingOverride && isCompleteScopeLock(scopeLock), [editingOverride, scopeLock]); + + const isSaving = isCreating || isPatching; + const providers = useMemo(() => (providersError ? [] : (providersData ?? [])), [providersData, providersError]); + const virtualKeys = useMemo(() => (virtualKeysError ? [] : (virtualKeysData?.virtual_keys ?? [])), [virtualKeysData, virtualKeysError]); + + const providerKeyOptions = useMemo( + () => + allKeysData.map((key) => ({ + id: key.key_id, + providerName: key.provider, + label: key.name || key.key_id, + })), + [allKeysData], + ); + const providerScopedKeyOptions = useMemo( + () => providerKeyOptions.filter((key) => key.providerName === form.providerID), + [providerKeyOptions, form.providerID], + ); + + // Hydrate the form only when the sheet transitions from closed β†’ open. + // This prevents providerKeyOptions refetches from resetting unsaved edits. + useEffect(() => { + const wasOpen = prevOpenRef.current; + prevOpenRef.current = open; + if (!open || wasOpen) return; + + jsonEditingRef.current = false; + setJSONError(undefined); + if (editingOverride) { + const state = toFormState(editingOverride); + // For provider_key scopes, provider_id is not stored in the DB (it's implicit from + // the key). Derive it from providerKeyOptions so the provider selector renders and + // the filtered key list shows the pre-selected key correctly. + if (!state.providerID && state.providerKeyID) { + const match = providerKeyOptions.find((k) => k.id === state.providerKeyID); + if (match) state.providerID = match.providerName; + } + setForm(state); + return; + } + if (shouldLockScope && scopeLock) { + const scopedForm: FormState = { + ...defaultFormState, + virtualKeyID: scopeLock.virtualKeyID ?? "", + providerID: scopeLock.providerID ?? "", + providerKeyID: scopeLock.providerKeyID ?? "", + scopeRoot: + scopeLock.scopeKind === "virtual_key" || + scopeLock.scopeKind === "virtual_key_provider" || + scopeLock.scopeKind === "virtual_key_provider_key" + ? "virtual_key" + : "global", + }; + setForm(scopedForm); + return; + } + setForm(defaultFormState); + }, [open, editingOverride, scopeLock, shouldLockScope, providerKeyOptions]); + + // When providerKeyOptions loads after the sheet is already open in edit mode, + // backfill the derived providerID without resetting the rest of the form. + useEffect(() => { + if (!open || !editingOverride) return; + setForm((prev) => { + if (prev.providerID || !prev.providerKeyID) return prev; + const match = providerKeyOptions.find((k) => k.id === prev.providerKeyID); + if (!match) return prev; + return { ...prev, providerID: match.providerName }; + }); + }, [providerKeyOptions, open, editingOverride]); + + const resolvedScopeKind = useMemo(() => { + if (shouldLockScope && scopeLock?.scopeKind) return scopeLock.scopeKind; + return deriveScopeKind(form); + }, [scopeLock, shouldLockScope, form]); + + const resolvedVirtualKeyID = useMemo(() => { + if (shouldLockScope) return scopeLock?.virtualKeyID; + return form.scopeRoot === "virtual_key" ? form.virtualKeyID || undefined : undefined; + }, [scopeLock, shouldLockScope, form.scopeRoot, form.virtualKeyID]); + + const resolvedProviderID = useMemo(() => { + if (shouldLockScope) return scopeLock?.providerID; + return form.providerID || undefined; + }, [scopeLock, shouldLockScope, form.providerID]); + + const resolvedProviderKeyID = useMemo(() => { + if (shouldLockScope) return scopeLock?.providerKeyID; + return form.providerKeyID || undefined; + }, [scopeLock, shouldLockScope, form.providerKeyID]); + + const pricingFieldErrors = useMemo(() => { + const errors: FieldErrors = {}; + for (const key of patchKeys) { + const raw = form.pricingValues[key]; + if (!raw || raw.trim() === "") continue; + const parsed = Number(raw); + if (!Number.isFinite(parsed)) errors[key] = "Must be a number"; + else if (parsed < 0) errors[key] = "Must be >= 0"; + } + return errors; + }, [form.pricingValues]); + + useEffect(() => { + if (!jsonEditingRef.current) { + const { patch } = buildPatchFromForm(form); + const json = Object.keys(patch).length > 0 ? JSON.stringify(patch, null, 2) : ""; + setJSONPatch(json); + setJSONError(undefined); + } + }, [form]); + + const handleJSONChange = useCallback((value: string) => { + jsonEditingRef.current = true; + setJSONPatch(value); + const trimmed = value.trim(); + if (!trimmed) { + setJSONError(undefined); + setForm((prev) => ({ ...prev, pricingValues: {} })); + return; + } + try { + const parsed = JSON.parse(trimmed); + if (parsed == null || typeof parsed !== "object" || Array.isArray(parsed)) { + setJSONError("Patch must be a JSON object"); + return; + } + const pricingValues: Partial> = {}; + for (const [key, val] of Object.entries(parsed)) { + if (!patchKeys.includes(key as PricingFieldKey)) { + setJSONError(`Unknown field: ${key}`); + return; + } + if (typeof val !== "number" || Number.isNaN(val) || val < 0) { + setJSONError(`${key} must be a non-negative number`); + return; + } + pricingValues[key as PricingFieldKey] = String(val); + } + setJSONError(undefined); + setForm((prev) => ({ ...prev, pricingValues })); + } catch { + setJSONError("Invalid JSON"); + } + }, []); + + const handleFieldChange = useCallback(() => { + jsonEditingRef.current = false; + }, []); + + const handleCloseDrawer = () => { + onOpenChange(false); + setRequestTypePopoverOpen(false); + }; + + const toggleRequestType = (requestType: RequestType) => { + setForm((prev) => ({ + ...prev, + requestTypes: prev.requestTypes.includes(requestType) + ? prev.requestTypes.filter((item) => item !== requestType) + : [...prev.requestTypes, requestType], + })); + }; + + const handleSave = async () => { + if (!form.name.trim()) { + toast.error("Name is required"); + return; + } + + if ( + (resolvedScopeKind === "virtual_key" || + resolvedScopeKind === "virtual_key_provider" || + resolvedScopeKind === "virtual_key_provider_key") && + !resolvedVirtualKeyID + ) { + toast.error("Virtual key is required"); + return; + } + if ((resolvedScopeKind === "provider" || resolvedScopeKind === "virtual_key_provider") && !resolvedProviderID) { + toast.error("Provider is required"); + return; + } + if (resolvedScopeKind === "provider_key" && !resolvedProviderKeyID) { + toast.error("Provider key is required"); + return; + } + if (resolvedScopeKind === "virtual_key_provider_key" && (!resolvedProviderID || !resolvedProviderKeyID)) { + toast.error("Provider and provider key are required"); + return; + } + + const pError = patternError(form.matchType, form.pattern); + if (pError) { + toast.error(pError); + return; + } + + if (form.requestTypes.length === 0) { + toast.error("At least one request type must be selected"); + return; + } + + if (jsonError) { + toast.error("Fix the JSON error before saving"); + return; + } + + const { patch, errors: pricingErrors } = buildPatchFromForm(form); + const firstPricingError = Object.values(pricingErrors)[0]; + if (firstPricingError) { + toast.error(firstPricingError); + return; + } + if (Object.keys(patch).length === 0) { + toast.error("At least one pricing field must be overridden"); + return; + } + + let scopedVirtualKeyID: string | undefined; + let scopedProviderID: string | undefined; + let scopedProviderKeyID: string | undefined; + + switch (resolvedScopeKind) { + case "global": + break; + case "provider": + scopedProviderID = resolvedProviderID; + break; + case "provider_key": + scopedProviderKeyID = resolvedProviderKeyID; + break; + case "virtual_key": + scopedVirtualKeyID = resolvedVirtualKeyID; + break; + case "virtual_key_provider": + scopedVirtualKeyID = resolvedVirtualKeyID; + scopedProviderID = resolvedProviderID; + break; + case "virtual_key_provider_key": + scopedVirtualKeyID = resolvedVirtualKeyID; + scopedProviderID = resolvedProviderID; + scopedProviderKeyID = resolvedProviderKeyID; + break; + } + + const requestPayload: CreatePricingOverrideRequest = { + name: form.name.trim(), + scope_kind: resolvedScopeKind, + virtual_key_id: scopedVirtualKeyID, + provider_id: scopedProviderID, + provider_key_id: scopedProviderKeyID, + match_type: form.matchType, + pattern: form.pattern.trim(), + request_types: form.requestTypes.length > 0 ? form.requestTypes : [], + patch, + }; + + try { + if (editingOverride) { + await updateOverride({ id: editingOverride.id, data: requestPayload }).unwrap(); + toast.success("Pricing override updated"); + } else { + await createOverride(requestPayload).unwrap(); + toast.success("Pricing override created"); + } + handleCloseDrawer(); + onSaved?.(); + } catch (error) { + toast.error("Failed to save pricing override", { description: getErrorMessage(error) }); + } + }; + + return ( + (o ? onOpenChange(true) : handleCloseDrawer())}> + + + {editingOverride ? "Edit Pricing Override" : "Create Pricing Override"} + + +
+
+
+ + setForm((prev) => ({ ...prev, name: e.target.value }))} + /> +
+ + {shouldLockScope && scopeLock ? ( +
+ + +
+ ) : ( + <> +
+ + +
+ + {form.scopeRoot === "virtual_key" && ( +
+ + + {virtualKeysError ? ( +

Failed to load virtual keys: {getErrorMessage(virtualKeysError)}

+ ) : null} +
+ )} + +
+
+ + + {providersError ? ( +

Failed to load providers: {getErrorMessage(providersError)}

+ ) : null} +
+ + {form.providerID ? ( +
+ + +
+ ) : ( +
+ )} +
+ + )} +
+ +
+
+
+ + +
+
+ + setForm((prev) => ({ ...prev, pattern: e.target.value }))} + placeholder={form.matchType === "exact" ? "e.g., gpt-4o" : "e.g., gpt-4*"} + /> +
+
+
+ +
+ + + + + + e.stopPropagation()}> +
e.stopPropagation()}> + {REQUEST_TYPE_GROUPS.map((group) => ( +
+
{group.label}
+ {group.types.map((requestType) => { + const checked = form.requestTypes.includes(requestType); + return ( + + ); + })} +
+ ))} +
+
+ +
+
+
+
+ +
+ + { + handleFieldChange(); + setForm((prev) => ({ ...prev, pricingValues: { ...prev.pricingValues, [key]: value } })); + }} + onFieldInteraction={handleFieldChange} + /> +
+ +
+ +
+ +
+ {jsonError &&

{jsonError}

} +
+
+ +
+ + +
+ + + ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx b/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx new file mode 100644 index 0000000000..52c6dae93b --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx @@ -0,0 +1,45 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { ArrowUpRight, SlidersHorizontal } from "lucide-react"; + +const PRICING_OVERRIDES_DOCS_URL = "https://docs.getbifrost.ai/features/governance/custom-pricing"; + +interface PricingOverridesEmptyStateProps { + onCreateClick: () => void; +} + +export function PricingOverridesEmptyState({ onCreateClick }: PricingOverridesEmptyStateProps) { + return ( +
+
+ +
+
+

Pricing overrides customize cost tracking per scope

+
+ Define custom per-token prices for specific providers, keys, or virtual keys to accurately reflect your negotiated rates. +
+
+ + +
+
+
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx b/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx new file mode 100644 index 0000000000..006aaf75d1 --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx @@ -0,0 +1,390 @@ +"use client"; + +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alertDialog"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; +import { + getErrorMessage, + useDeletePricingOverrideMutation, + useGetPricingOverridesQuery, + useGetProvidersQuery, + useGetVirtualKeysQuery, +} from "@/lib/store"; +import { useGetAllKeysQuery } from "@/lib/store/apis/providersApi"; +import { ProviderIconType, RenderProviderIcon } from "@/lib/constants/icons"; +import { getProviderLabel } from "@/lib/constants/logs"; +import { PricingOverride, PricingOverrideScopeKind } from "@/lib/types/governance"; +import { useDebouncedValue } from "@/hooks/useDebounce"; +import { Input } from "@/components/ui/input"; +import { ChevronLeft, ChevronRight, Edit, Plus, Search, Trash2 } from "lucide-react"; +import { useSearchParams } from "next/navigation"; +import { useEffect, useMemo, useState } from "react"; +import { toast } from "sonner"; +import PricingOverrideSheet from "./pricingOverrideSheet"; +import { PricingOverridesEmptyState } from "./pricingOverridesEmptyState"; + +type ScopeFilter = "all" | PricingOverrideScopeKind; + +function parseScopeKind(value: string | null): ScopeFilter { + if ( + value === "global" || + value === "provider" || + value === "provider_key" || + value === "virtual_key" || + value === "virtual_key_provider" || + value === "virtual_key_provider_key" + ) { + return value; + } + return "all"; +} + +// Returns the top-level scope label: "Global" or the virtual key name. +function scopeLabel(override: PricingOverride, virtualKeyMap: Map): string { + const scopeKind = resolveScopeKind(override); + if (override.virtual_key_id && scopeKind.startsWith("virtual_key")) { + return "Virtual Key"; + } + return "Global"; +} + +// Returns the key label for the override, or "-" when no specific key is scoped. +function keyLabel(override: PricingOverride, keyLabelMap: Map): string { + if (!override.provider_key_id) { + if (!override.provider_id) return "-"; + return "All Keys" + }; + return keyLabelMap.get(override.provider_key_id) || override.provider_key_id; +} + +// Returns the provider label for the override, or "-" if not applicable. +function providerLabel(override: PricingOverride, providerMap: Map, keyProviderMap: Map): string { + const scopeKind = resolveScopeKind(override); + switch (scopeKind) { + case "provider": + case "virtual_key_provider": + return providerMap.get(override.provider_id || "") || override.provider_id || "-"; + case "provider_key": + case "virtual_key_provider_key": { + const keyID = override.provider_key_id || ""; + return providerMap.get(keyProviderMap.get(keyID) || "") || keyProviderMap.get(keyID) || "-"; + } + default: + return "-"; + } +} + +function resolveScopeKind(override: PricingOverride): PricingOverrideScopeKind { + if ( + override.scope_kind === "global" || + override.scope_kind === "provider" || + override.scope_kind === "provider_key" || + override.scope_kind === "virtual_key" || + override.scope_kind === "virtual_key_provider" || + override.scope_kind === "virtual_key_provider_key" + ) { + return override.scope_kind; + } + if (override.virtual_key_id) { + if (override.provider_key_id) return "virtual_key_provider_key"; + if (override.provider_id) return "virtual_key_provider"; + return "virtual_key"; + } + if (override.provider_key_id) return "provider_key"; + if (override.provider_id) return "provider"; + return "global"; +} + +const PAGE_SIZE = 25; + +export default function ScopedPricingOverridesView() { + const searchParams = useSearchParams(); + + const [scopeKind, setScopeKind] = useState(() => parseScopeKind(searchParams.get("scope_kind"))); + const [virtualKeyID, setVirtualKeyID] = useState(() => (searchParams.get("virtual_key_id") || "").trim()); + const [providerID, setProviderID] = useState(() => (searchParams.get("provider_id") || "").trim()); + const [providerKeyID, setProviderKeyID] = useState(() => (searchParams.get("provider_key_id") || "").trim()); + + const [search, setSearch] = useState(""); + const [offset, setOffset] = useState(0); + const debouncedSearch = useDebouncedValue(search, 300); + + useEffect(() => { + setScopeKind(parseScopeKind(searchParams.get("scope_kind"))); + setVirtualKeyID((searchParams.get("virtual_key_id") || "").trim()); + setProviderID((searchParams.get("provider_id") || "").trim()); + setProviderKeyID((searchParams.get("provider_key_id") || "").trim()); + }, [searchParams]); + + // Reset to first page when filters or search change + useEffect(() => { + setOffset(0); + }, [scopeKind, virtualKeyID, providerID, providerKeyID, debouncedSearch]); + + const queryArgs = useMemo(() => ({ + scopeKind: scopeKind === "all" ? undefined : scopeKind, + virtualKeyID: virtualKeyID || undefined, + providerID: providerID || undefined, + providerKeyID: providerKeyID || undefined, + limit: PAGE_SIZE, + offset, + search: debouncedSearch || undefined, + }), [scopeKind, virtualKeyID, providerID, providerKeyID, offset, debouncedSearch]); + + const { data, isLoading, error } = useGetPricingOverridesQuery(queryArgs); + + // Snap offset back when total shrinks past current page + const totalCount = data?.total_count ?? 0; + useEffect(() => { + if (!data || offset < totalCount) return; + setOffset(totalCount === 0 ? 0 : Math.floor((totalCount - 1) / PAGE_SIZE) * PAGE_SIZE); + }, [totalCount, offset]); + const { data: providersData } = useGetProvidersQuery(); + const { data: virtualKeysData } = useGetVirtualKeysQuery(); + const { data: allKeysData = [] } = useGetAllKeysQuery(); + const [deleteOverride, { isLoading: isDeleting }] = useDeletePricingOverrideMutation(); + + useEffect(() => { + if (error) { + toast.error("Failed to load pricing overrides", { description: getErrorMessage(error) }); + } + }, [error]); + + const [isDrawerOpen, setIsDrawerOpen] = useState(false); + const [editingOverride, setEditingOverride] = useState(null); + const [deleteTarget, setDeleteTarget] = useState(null); + + const rows = data?.pricing_overrides ?? []; + const providers = useMemo(() => providersData ?? [], [providersData]); + const virtualKeys = useMemo(() => virtualKeysData?.virtual_keys ?? [], [virtualKeysData]); + + const providerMap = useMemo(() => new Map(providers.map((provider) => [provider.name, provider.name])), [providers]); + const providerKeyOptions = useMemo( + () => + allKeysData.map((key) => ({ + id: key.key_id, + label: key.name || key.key_id, + providerName: key.provider, + })), + [allKeysData], + ); + const providerKeyProviderMap = useMemo( + () => new Map(providerKeyOptions.map((key) => [key.id, key.providerName])), + [providerKeyOptions], + ); + const providerKeyLabelMap = useMemo( + () => new Map(providerKeyOptions.map((key) => [key.id, key.label])), + [providerKeyOptions], + ); + const virtualKeyMap = useMemo(() => new Map(virtualKeys.map((vk) => [vk.id, vk.name])), [virtualKeys]); + + const createScopeLock = useMemo(() => { + if (scopeKind === "all") return undefined; + return { + scopeKind, + virtualKeyID: virtualKeyID || undefined, + providerID: providerID || undefined, + providerKeyID: providerKeyID || undefined, + label: `${scopeKind}${virtualKeyID || providerID || providerKeyID ? " (filtered)" : ""}`, + }; + }, [scopeKind, virtualKeyID, providerID, providerKeyID]); + + const openCreateDrawer = () => { + setEditingOverride(null); + setIsDrawerOpen(true); + }; + + const openEditDrawer = (override: PricingOverride) => { + setEditingOverride(override); + setIsDrawerOpen(true); + }; + + const handleDeleteConfirm = async () => { + if (!deleteTarget) return; + try { + await deleteOverride(deleteTarget.id).unwrap(); + toast.success("Pricing override deleted"); + setDeleteTarget(null); + } catch (deleteError) { + toast.error("Failed to delete pricing override", { description: getErrorMessage(deleteError) }); + } + }; + + const hasActiveFilters = debouncedSearch || scopeKind !== "all" || virtualKeyID || providerID || providerKeyID; + + if (!isLoading && !error && totalCount === 0 && !hasActiveFilters) { + return ( + <> + + + + ); + } + + return ( +
+
+
+

Pricing Overrides

+

Set custom rates for any model across global or virtual key scopes, optionally narrowed to a specific provider or key

+
+ +
+ + {/* Search */} +
+ + setSearch(e.target.value)} + className="pl-9" + data-testid="pricing-overrides-search-input" + /> +
+ +
+ {isLoading ? ( +
Loading overrides...
+ ) : error ? ( +
Failed to load pricing overrides. Please try refreshing the page.
+ ) : ( + + + + Name + Scope + Provider + Key + Model + Actions + + + + {rows.length === 0 ? ( + + + No matching pricing overrides found. + + + ) : rows.map((row) => ( + + {row.name || "-"} + + {scopeLabel(row, virtualKeyMap)} + + + {(() => { + const name = providerLabel(row, providerMap, providerKeyProviderMap); + if (name === "-") return -; + return ( +
+ + {getProviderLabel(name)} +
+ ); + })()} +
+ {keyLabel(row, providerKeyLabelMap)} + {row.pattern} + e.stopPropagation()}> +
+ + +
+
+
+ ))} +
+
+ )} +
+ + {/* Pagination */} + {totalCount > 0 && ( +
+

+ Showing {offset + 1}-{Math.min(offset + PAGE_SIZE, totalCount)} of {totalCount} +

+
+ + +
+
+ )} + + + + (!open ? setDeleteTarget(null) : undefined)}> + + + Delete Pricing Override + + Are you sure you want to delete "{deleteTarget?.name}"? This action cannot be undone. + + + + Cancel + { + e.preventDefault(); + void handleDeleteConfirm(); + }} + disabled={isDeleting} + className="bg-destructive hover:bg-destructive/90" + > + {isDeleting ? "Deleting..." : "Delete"} + + + + +
+ ); +} diff --git a/ui/app/workspace/custom-pricing/page.tsx b/ui/app/workspace/custom-pricing/page.tsx index 80932b3325..7c23736637 100644 --- a/ui/app/workspace/custom-pricing/page.tsx +++ b/ui/app/workspace/custom-pricing/page.tsx @@ -1,11 +1,11 @@ "use client" -import PricingConfigView from "@/app/workspace/config/views/pricingConfigView" +import ModelSettingsView from "@/app/workspace/config/views/modelSettingsView" export default function CustomPricingPage() { return (
- +
) } diff --git a/ui/app/workspace/dashboard/components/charts/modelFilterSelect.tsx b/ui/app/workspace/dashboard/components/charts/modelFilterSelect.tsx index fe8d459ac1..c605182525 100644 --- a/ui/app/workspace/dashboard/components/charts/modelFilterSelect.tsx +++ b/ui/app/workspace/dashboard/components/charts/modelFilterSelect.tsx @@ -10,10 +10,16 @@ interface ModelFilterSelectProps { "data-testid"?: string; } -export function ModelFilterSelect({ models, selectedModel, onModelChange, placeholder = "All Models", "data-testid": testId }: ModelFilterSelectProps) { +export function ModelFilterSelect({ + models, + selectedModel, + onModelChange, + placeholder = "All Models", + "data-testid": testId, +}: ModelFilterSelectProps) { return (
@@ -418,7 +421,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => {
)} - {form.auth_type === "oauth" && ( + {(form.auth_type === "oauth" || form.auth_type === "per_user_oauth") && ( <>
@@ -541,6 +544,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => {
)} +
{/* Form Footer */}
@@ -605,6 +609,7 @@ const ClientForm: React.FC = ({ open, onClose, onSaved }) => { authorizeUrl={oauthFlow.authorizeUrl} oauthConfigId={oauthFlow.oauthConfigId} mcpClientId={oauthFlow.mcpClientId} + isPerUserOauth={oauthFlow.isPerUserOauth} /> )} diff --git a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx index ee042c0414..609a47b3bb 100644 --- a/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx +++ b/ui/app/workspace/mcp-registry/views/mcpClientSheet.tsx @@ -7,6 +7,8 @@ import { Button } from "@/components/ui/button"; import { Form, FormControl, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; import { HeadersTable } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; +import { MultiSelect } from "@/components/ui/multiSelect"; +import { Select, SelectContent, SelectItem, SelectTrigger } from "@/components/ui/select"; import { Sheet, SheetContent, SheetDescription, SheetHeader, SheetTitle } from "@/components/ui/sheet"; import { Switch } from "@/components/ui/switch"; import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; @@ -14,13 +16,14 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp import { TriStateCheckbox } from "@/components/ui/tristateCheckbox"; import { useToast } from "@/hooks/use-toast"; import { MCP_STATUS_COLORS } from "@/lib/constants/config"; -import { getErrorMessage, useGetCoreConfigQuery, useUpdateMCPClientMutation } from "@/lib/store"; -import { MCPClient } from "@/lib/types/mcp"; +import { getErrorMessage, useGetCoreConfigQuery, useGetVirtualKeysQuery, useUpdateMCPClientMutation } from "@/lib/store"; +import { MCPClient, MCPVKConfig } from "@/lib/types/mcp"; import { mcpClientUpdateSchema, type MCPClientUpdateSchema } from "@/lib/types/schemas"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; import { zodResolver } from "@hookform/resolvers/zod"; -import { ChevronDown, ChevronRight, Info } from "lucide-react"; -import { useEffect, useState } from "react"; +import { ChevronDown, ChevronRight, Info, Plus, Trash2 } from "lucide-react"; +import { useDebouncedValue } from "@/hooks/useDebounce"; +import { useEffect, useMemo, useState } from "react"; import { useForm } from "react-hook-form"; import { CodeEditor } from "@/components/ui/codeEditor"; @@ -47,6 +50,81 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: const { toast } = useToast(); const [expandedTools, setExpandedTools] = useState>(new Set()); + // VK access management β€” search-based dropdown (limit 20), no pagination issue + const [vkSearch, setVKSearch] = useState(""); + const [vkSelectValue, setVKSelectValue] = useState(""); + const debouncedVkSearch = useDebouncedValue(vkSearch, 300); + const { data: vksData } = useGetVirtualKeysQuery({ limit: 20, search: debouncedVkSearch || undefined }); + const allToolNames = useMemo(() => mcpClient.tools?.map((t) => t.name) ?? [], [mcpClient.tools]); + + // Initial VK configs come directly from the MCP client response β€” always complete, no pagination issue. + const initialVKConfigs = useMemo( + () => (mcpClient.vk_configs ?? []).map((vc) => ({ virtual_key_id: vc.virtual_key_id, tools_to_execute: vc.tools_to_execute })), + [mcpClient.vk_configs], + ); + + const [vkConfigs, setVKConfigs] = useState([]); + const [vkConfigsDirty, setVKConfigsDirty] = useState(false); + const [allowedExtraHeadersRaw, setAllowedExtraHeadersRaw] = useState( + (mcpClient.config.allowed_extra_headers || []).join(", "), + ); + // Persists names for newly added VKs so they survive search result changes + const [localVKNames, setLocalVKNames] = useState>({}); + + // Sync vkConfigs when mcpClient changes + useEffect(() => { + setVKConfigs(initialVKConfigs); + setVKConfigsDirty(false); + setLocalVKNames({}); + }, [initialVKConfigs]); + + // Sync allowedExtraHeadersRaw when mcpClient changes + useEffect(() => { + setAllowedExtraHeadersRaw((mcpClient.config.allowed_extra_headers || []).join(", ")); + }, [mcpClient.config.allowed_extra_headers]); + + // Name lookup: server response names β†’ search results β†’ locally cached names (highest priority) + const vkNameByID = useMemo>(() => { + const m: Record = {}; + for (const vc of mcpClient.vk_configs ?? []) m[vc.virtual_key_id] = vc.virtual_key_name; + for (const vk of vksData?.virtual_keys ?? []) m[vk.id] = vk.name; + Object.assign(m, localVKNames); + return m; + }, [mcpClient.vk_configs, vksData, localVKNames]); + + const vkOptions = useMemo( + () => + (vksData?.virtual_keys ?? []) + .filter((vk) => !vkConfigs.some((vc) => vc.virtual_key_id === vk.id)) + .map((vk) => ({ value: vk.id, label: vk.name })), + [vksData, vkConfigs], + ); + + const toolOptions = useMemo( + () => [ + { value: "*", label: "Allow All Tools", description: "Allow all current and future tools" }, + ...allToolNames.map((n) => ({ value: n, label: n })), + ], + [allToolNames], + ); + + const addVKConfig = (vkId: string) => { + const name = vksData?.virtual_keys?.find((vk) => vk.id === vkId)?.name; + if (name) setLocalVKNames((prev) => ({ ...prev, [vkId]: name })); + setVKConfigs((prev) => [...prev, { virtual_key_id: vkId, tools_to_execute: ["*"] }]); + setVKConfigsDirty(true); + }; + + const removeVKConfig = (vkId: string) => { + setVKConfigs((prev) => prev.filter((vc) => vc.virtual_key_id !== vkId)); + setVKConfigsDirty(true); + }; + + const updateVKConfigTools = (vkId: string, tools: string[]) => { + setVKConfigs((prev) => prev.map((vc) => (vc.virtual_key_id === vkId ? { ...vc, tools_to_execute: tools } : vc))); + setVKConfigsDirty(true); + }; + const toggleToolExpanded = (toolName: string) => { setExpandedTools((prev) => { const next = new Set(prev); @@ -66,11 +144,13 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: name: mcpClient.config.name, is_code_mode_client: mcpClient.config.is_code_mode_client || false, is_ping_available: mcpClient.config.is_ping_available === true || mcpClient.config.is_ping_available === undefined, + allow_on_all_virtual_keys: mcpClient.config.allow_on_all_virtual_keys || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, tool_sync_interval: toolSyncIntervalToMinutes(mcpClient.config.tool_sync_interval), + allowed_extra_headers: mcpClient.config.allowed_extra_headers || [], }, }); @@ -80,11 +160,13 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: name: mcpClient.config.name, is_code_mode_client: mcpClient.config.is_code_mode_client || false, is_ping_available: mcpClient.config.is_ping_available === true || mcpClient.config.is_ping_available === undefined, + allow_on_all_virtual_keys: mcpClient.config.allow_on_all_virtual_keys || false, headers: mcpClient.config.headers, tools_to_execute: mcpClient.config.tools_to_execute || [], tools_to_auto_execute: mcpClient.config.tools_to_auto_execute || [], tool_pricing: mcpClient.config.tool_pricing || {}, tool_sync_interval: toolSyncIntervalToMinutes(mcpClient.config.tool_sync_interval), + allowed_extra_headers: mcpClient.config.allowed_extra_headers || [], }); }, [form, mcpClient]); @@ -96,11 +178,14 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }: name: data.name, is_code_mode_client: data.is_code_mode_client, is_ping_available: data.is_ping_available, + allow_on_all_virtual_keys: data.allow_on_all_virtual_keys, headers: data.headers ?? {}, tools_to_execute: data.tools_to_execute, tools_to_auto_execute: data.tools_to_auto_execute, tool_pricing: data.tool_pricing, tool_sync_interval: data.tool_sync_interval ?? 0, + allowed_extra_headers: data.allowed_extra_headers, + vk_configs: vkConfigsDirty ? vkConfigs : undefined, }, }).unwrap(); @@ -235,7 +320,7 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }:
{/* Client Configuration */}
@@ -660,11 +823,134 @@ export default function MCPClientSheet({ mcpClient, onClose, onSubmitSuccess }:

No tools available

)} + + {mcpClient.tools && mcpClient.tools.length > 0 && ( +
+
+
+
+
Virtual Key Access
+ + + + + + +

Control which virtual keys can use this MCP server and which specific tools they can call.

+
+
+
+
+ {vkOptions.length > 0 && ( + setVKSearch(e.target.value)} + onKeyDown={(e) => e.stopPropagation()} + className="h-7 text-sm" + /> +
+ {vkOptions.length > 0 ? vkOptions.map((opt) => ( + + {opt.label} + + )) : ( +
No virtual keys found
+ )} + + + )} +
+ {form.watch("allow_on_all_virtual_keys") && ( +

+ + Configuring access for a virtual key here overrides the{" "} + Allow on All Virtual Keys setting for that key. +

+ )} +
+ + {vkConfigs.length > 0 ? ( +
+ + + + Virtual Key + Allowed Tools + + + + + {vkConfigs.map((vc) => ( + + {vkNameByID[vc.virtual_key_id] ?? vc.virtual_key_id} + + { + const hadStar = vc.tools_to_execute.includes("*"); + const hasStar = tools.includes("*"); + let next: string[]; + if (!hadStar && hasStar) { + next = ["*"]; + } else if (hadStar && hasStar && tools.length > 1) { + next = tools.filter((t) => t !== "*"); + } else { + next = tools; + } + updateVKConfigTools(vc.virtual_key_id, next); + }} + placeholder={vc.tools_to_execute.includes("*") ? "All tools allowed" : vc.tools_to_execute.length === 0 ? "No tools allowed" : "Select tools..."} + maxCount={3} + className="bg-background dark:bg-input/30 border-input rounded-sm text-foreground hover:bg-accent hover:text-accent-foreground font-normal" + /> + + + + + + ))} + +
+
+ ) : form.watch("allow_on_all_virtual_keys") ? ( +
+

All virtual keys can access this MCP server unless a key has an explicit override.

+
+ ) : ( +
+

No virtual keys have access to this MCP server

+
+ )} +
+ )} - + ); } diff --git a/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx b/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx index b43270f863..f4a32d4e08 100644 --- a/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx +++ b/ui/app/workspace/mcp-registry/views/oauth2Authorizer.tsx @@ -15,6 +15,7 @@ interface OAuth2AuthorizerProps { authorizeUrl: string oauthConfigId: string mcpClientId: string + isPerUserOauth?: boolean } export const OAuth2Authorizer: React.FC = ({ @@ -25,8 +26,9 @@ export const OAuth2Authorizer: React.FC = ({ authorizeUrl, oauthConfigId, mcpClientId, + isPerUserOauth, }) => { - const [status, setStatus] = useState<"pending" | "polling" | "success" | "failed">("pending") + const [status, setStatus] = useState<"confirm" | "pending" | "polling" | "success" | "failed">(isPerUserOauth ? "confirm" : "pending") const [errorMessage, setErrorMessage] = useState(null) const popupRef = useRef(null) const pollIntervalRef = useRef(null) @@ -169,13 +171,19 @@ export const OAuth2Authorizer: React.FC = ({ } }, [checkOAuthStatus]) - // Open popup when dialog opens + // Open popup when dialog opens (skip if waiting for user confirmation) useEffect(() => { if (open && status === "pending") { openPopup() } }, [open, status, openPopup]) + // Handle user confirming per-user OAuth test + const handleConfirmPerUserOAuth = () => { + setStatus("pending") + openPopup() + } + // Cleanup on unmount useEffect(() => { return () => { @@ -187,9 +195,13 @@ export const OAuth2Authorizer: React.FC = ({ }, [stopPolling]) const handleRetry = () => { - setStatus("pending") setErrorMessage(null) - openPopup() + if (isPerUserOauth) { + setStatus("confirm") + } else { + setStatus("pending") + openPopup() + } } const handleCancel = () => { @@ -204,8 +216,9 @@ export const OAuth2Authorizer: React.FC = ({ e.preventDefault()} onEscapeKeyDown={(e) => e.preventDefault()}> - OAuth Authorization + {status === "confirm" ? "Test OAuth Configuration" : "OAuth Authorization"} + {status === "confirm" && "A one-time login is needed to verify your OAuth setup."} {status === "pending" && "Opening authorization window..."} {status === "polling" && "Waiting for authorization..."} {status === "success" && "Authorization successful!"} @@ -214,6 +227,30 @@ export const OAuth2Authorizer: React.FC = ({
+ {status === "confirm" && ( + <> +
+

+ To set up this MCP server, we need to verify that your OAuth configuration is correct and discover the available tools. +

+

+ You will be asked to log in to the OAuth provider. This is a one-time test to confirm the setup works. Your credentials will not be stored or used for any other purpose. +

+

+ Once verified, each user will authenticate individually when they use this MCP server. +

+
+
+ + +
+ + )} + {status === "polling" && ( <> diff --git a/ui/app/workspace/model-limits/views/modelLimitSheet.tsx b/ui/app/workspace/model-limits/views/modelLimitSheet.tsx index 27fcb8f49d..0a761cbba0 100644 --- a/ui/app/workspace/model-limits/views/modelLimitSheet.tsx +++ b/ui/app/workspace/model-limits/views/modelLimitSheet.tsx @@ -12,7 +12,13 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp import { resetDurationOptions } from "@/lib/constants/governance"; import { RenderProviderIcon } from "@/lib/constants/icons"; import { ProviderLabels, ProviderName } from "@/lib/constants/logs"; -import { getErrorMessage, useCreateModelConfigMutation, useGetProvidersQuery, useLazyGetModelsQuery, useUpdateModelConfigMutation } from "@/lib/store"; +import { + getErrorMessage, + useCreateModelConfigMutation, + useGetProvidersQuery, + useLazyGetModelsQuery, + useUpdateModelConfigMutation, +} from "@/lib/store"; import { KnownProvider } from "@/lib/types/config"; import { ModelConfig } from "@/lib/types/governance"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; @@ -31,11 +37,11 @@ interface ModelLimitSheetProps { const formSchema = z.object({ modelName: z.string().min(1, "Model name is required"), provider: z.string().optional(), - budgetMaxLimit: z.string().optional(), + budgetMaxLimit: z.number().nonnegative().optional(), budgetResetDuration: z.string().optional(), - tokenMaxLimit: z.string().optional(), + tokenMaxLimit: z.number().int().nonnegative().optional(), tokenResetDuration: z.string().optional(), - requestMaxLimit: z.string().optional(), + requestMaxLimit: z.number().int().nonnegative().optional(), requestResetDuration: z.string().optional(), }); @@ -91,17 +97,19 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model defaultValues: { modelName: modelConfig?.model_name || "", provider: modelConfig?.provider || "", - budgetMaxLimit: modelConfig?.budget ? String(modelConfig.budget.max_limit) : "", + budgetMaxLimit: modelConfig?.budget?.max_limit ?? undefined, budgetResetDuration: modelConfig?.budget?.reset_duration || "1M", - tokenMaxLimit: modelConfig?.rate_limit?.token_max_limit ? String(modelConfig.rate_limit.token_max_limit) : "", + tokenMaxLimit: modelConfig?.rate_limit?.token_max_limit ?? undefined, tokenResetDuration: modelConfig?.rate_limit?.token_reset_duration || "1h", - requestMaxLimit: modelConfig?.rate_limit?.request_max_limit ? String(modelConfig.rate_limit.request_max_limit) : "", + requestMaxLimit: modelConfig?.rate_limit?.request_max_limit ?? undefined, requestResetDuration: modelConfig?.rate_limit?.request_reset_duration || "1h", }, }); - const parseLimit = (v: string | undefined) => { const n = parseFloat(v ?? ""); return !isNaN(n) && n > 0; }; - const hasAnyLimit = parseLimit(form.watch("budgetMaxLimit")) || parseLimit(form.watch("tokenMaxLimit")) || parseLimit(form.watch("requestMaxLimit")); + const hasAnyLimit = + (form.watch("budgetMaxLimit") !== undefined && form.watch("budgetMaxLimit") !== null) || + (form.watch("tokenMaxLimit") !== undefined && form.watch("tokenMaxLimit") !== null) || + (form.watch("requestMaxLimit") !== undefined && form.watch("requestMaxLimit") !== null); useEffect(() => { if (modelConfig) { @@ -112,11 +120,11 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model form.reset({ modelName: modelConfig.model_name || "", provider: modelConfig.provider || "", - budgetMaxLimit: modelConfig.budget ? String(modelConfig.budget.max_limit) : "", + budgetMaxLimit: modelConfig.budget?.max_limit ?? undefined, budgetResetDuration: modelConfig.budget?.reset_duration || "1M", - tokenMaxLimit: modelConfig.rate_limit?.token_max_limit ? String(modelConfig.rate_limit.token_max_limit) : "", + tokenMaxLimit: modelConfig.rate_limit?.token_max_limit ?? undefined, tokenResetDuration: modelConfig.rate_limit?.token_reset_duration || "1h", - requestMaxLimit: modelConfig.rate_limit?.request_max_limit ? String(modelConfig.rate_limit.request_max_limit) : "", + requestMaxLimit: modelConfig.rate_limit?.request_max_limit ?? undefined, requestResetDuration: modelConfig.rate_limit?.request_reset_duration || "1h", }); } @@ -129,21 +137,20 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model } try { - const budgetMaxLimit = data.budgetMaxLimit ? parseFloat(data.budgetMaxLimit) : undefined; - const tokenMaxLimit = data.tokenMaxLimit ? parseInt(data.tokenMaxLimit) : undefined; - const requestMaxLimit = data.requestMaxLimit ? parseInt(data.requestMaxLimit) : undefined; const provider = data.provider && data.provider.trim() !== "" ? data.provider : undefined; if (isEditing && modelConfig) { const hadBudget = !!modelConfig.budget; - const hasBudget = !!budgetMaxLimit; + const hasBudget = data.budgetMaxLimit !== undefined && data.budgetMaxLimit !== null; const hadRateLimit = !!modelConfig.rate_limit; - const hasRateLimit = !!tokenMaxLimit || !!requestMaxLimit; + const hasRateLimit = + (data.tokenMaxLimit !== undefined && data.tokenMaxLimit !== null) || + (data.requestMaxLimit !== undefined && data.requestMaxLimit !== null); let budgetPayload: { max_limit?: number; reset_duration?: string } | undefined; if (hasBudget) { budgetPayload = { - max_limit: budgetMaxLimit, + max_limit: data.budgetMaxLimit, reset_duration: data.budgetResetDuration || "1M", }; } else if (hadBudget) { @@ -152,18 +159,19 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model let rateLimitPayload: | { - token_max_limit?: number | null; - token_reset_duration?: string | null; - request_max_limit?: number | null; - request_reset_duration?: string | null; - } + token_max_limit?: number | null; + token_reset_duration?: string | null; + request_max_limit?: number | null; + request_reset_duration?: string | null; + } | undefined; if (hasRateLimit) { rateLimitPayload = { - token_max_limit: tokenMaxLimit ?? null, - token_reset_duration: tokenMaxLimit ? data.tokenResetDuration || "1h" : null, - request_max_limit: requestMaxLimit ?? null, - request_reset_duration: requestMaxLimit ? data.requestResetDuration || "1h" : null, + token_max_limit: data.tokenMaxLimit ?? null, + token_reset_duration: data.tokenMaxLimit !== undefined && data.tokenMaxLimit !== null ? data.tokenResetDuration || "1h" : null, + request_max_limit: data.requestMaxLimit ?? null, + request_reset_duration: + data.requestMaxLimit !== undefined && data.requestMaxLimit !== null ? data.requestResetDuration || "1h" : null, }; } else if (hadRateLimit) { rateLimitPayload = {}; @@ -183,20 +191,28 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model await createModelConfig({ model_name: data.modelName, provider, - budget: budgetMaxLimit - ? { - max_limit: budgetMaxLimit, - reset_duration: data.budgetResetDuration || "1M", - } - : undefined, + budget: + data.budgetMaxLimit !== undefined && data.budgetMaxLimit !== null + ? { + max_limit: data.budgetMaxLimit, + reset_duration: data.budgetResetDuration || "1M", + } + : undefined, rate_limit: - tokenMaxLimit || requestMaxLimit + (data.tokenMaxLimit !== undefined && data.tokenMaxLimit !== null) || + (data.requestMaxLimit !== undefined && data.requestMaxLimit !== null) ? { - token_max_limit: tokenMaxLimit, - token_reset_duration: data.tokenResetDuration || "1h", - request_max_limit: requestMaxLimit, - request_reset_duration: data.requestResetDuration || "1h", - } + token_max_limit: data.tokenMaxLimit, + token_reset_duration: + data.tokenMaxLimit !== undefined && data.tokenMaxLimit !== null + ? data.tokenResetDuration || "1h" + : undefined, + request_max_limit: data.requestMaxLimit, + request_reset_duration: + data.requestMaxLimit !== undefined && data.requestMaxLimit !== null + ? data.requestResetDuration || "1h" + : undefined, + } : undefined, }).unwrap(); toast.success("Model limit created successfully"); @@ -212,8 +228,12 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model !open && handleClose()}> { if (isEditing ? form.formState.isDirty : (!!form.watch("modelName") || hasAnyLimit)) e.preventDefault(); }} - onEscapeKeyDown={(e) => { if (isEditing ? form.formState.isDirty : (!!form.watch("modelName") || hasAnyLimit)) e.preventDefault(); }} + onInteractOutside={(e) => { + if (isEditing ? form.formState.isDirty : !!form.watch("modelName") || hasAnyLimit) e.preventDefault(); + }} + onEscapeKeyDown={(e) => { + if (isEditing ? form.formState.isDirty : !!form.watch("modelName") || hasAnyLimit) e.preventDefault(); + }} data-testid="model-limit-sheet" > @@ -235,7 +255,9 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model Provider @@ -303,7 +327,7 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model id="modelBudgetMaxLimit" labelClassName="font-normal" label="Maximum Spend (USD)" - value={field.value || ""} + value={field.value} selectValue={form.watch("budgetResetDuration") || "1M"} onChangeNumber={(value) => field.onChange(value)} onChangeSelect={(value) => form.setValue("budgetResetDuration", value, { shouldDirty: true })} @@ -330,7 +354,7 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model id="modelTokenMaxLimit" labelClassName="font-normal" label="Maximum Tokens" - value={field.value || ""} + value={field.value} selectValue={form.watch("tokenResetDuration") || "1h"} onChangeNumber={(value) => field.onChange(value)} onChangeSelect={(value) => form.setValue("tokenResetDuration", value, { shouldDirty: true })} @@ -350,7 +374,7 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model id="modelRequestMaxLimit" labelClassName="font-normal" label="Maximum Requests" - value={field.value || ""} + value={field.value} selectValue={form.watch("requestResetDuration") || "1h"} onChangeNumber={(value) => field.onChange(value)} onChangeSelect={(value) => form.setValue("requestResetDuration", value, { shouldDirty: true })} @@ -411,12 +435,28 @@ export default function ModelLimitSheet({ modelConfig, onSave, onCancel }: Model - - {(isLoading || !form.formState.isDirty || !form.formState.isValid || !canSubmit || !form.watch("modelName") || !hasAnyLimit) && ( + {(isLoading || + !form.formState.isDirty || + !form.formState.isValid || + !canSubmit || + !form.watch("modelName") || + !hasAnyLimit) && (

{!canSubmit diff --git a/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx b/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx index ef7f3af257..13efd1344e 100644 --- a/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx +++ b/ui/app/workspace/providers/dialogs/addNewCustomProviderSheet.tsx @@ -100,7 +100,6 @@ export function AddCustomProviderSheetContent({ show = true, onClose, onSave }: retry_backoff_initial: 500, retry_backoff_max: 5000, }, - keys: [], }; addProvider(payload) diff --git a/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx b/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx index dcdb06cb52..1c958c915f 100644 --- a/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx +++ b/ui/app/workspace/providers/dialogs/addNewKeySheet.tsx @@ -8,15 +8,16 @@ interface Props { show: boolean; onCancel: () => void; provider: ModelProvider; - keyIndex: number; + keyId: string | null; providerName?: string; } -export default function AddNewKeySheet({ show, onCancel, provider, keyIndex, providerName }: Props) { - const isEditing = keyIndex < provider.keys.length; +export default function AddNewKeySheet({ show, onCancel, provider, keyId, providerName }: Props) { + const isEditing = keyId !== null; const resolvedProviderName = (providerName ?? provider.name).toLowerCase(); const isVLLM = resolvedProviderName === "vllm"; - const entityLabel = isVLLM ? "model" : "key"; + const isOllamaOrSGL = resolvedProviderName === "ollama" || resolvedProviderName === "sgl"; + const entityLabel = isVLLM ? "model" : isOllamaOrSGL ? "server" : "key"; const EntityLabel = entityLabel.charAt(0).toUpperCase() + entityLabel.slice(1); const dialogTitle = isEditing ? `Edit ${entityLabel}` : `Add new ${entityLabel}`; const successMessage = isEditing ? `${EntityLabel} updated successfully` : `${EntityLabel} added successfully`; @@ -32,7 +33,6 @@ export default function AddNewKeySheet({ show, onCancel, provider, keyIndex, pro className="custom-scrollbar p-8" data-testid="key-form" onInteractOutside={(e) => e.preventDefault()} - onEscapeKeyDown={(e) => e.preventDefault()} > @@ -47,7 +47,7 @@ export default function AddNewKeySheet({ show, onCancel, provider, keyIndex, pro

{ toast.success(successMessage); diff --git a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx index 62653fd545..3a7888b0a5 100644 --- a/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/workspace/providers/fragments/apiKeysFormFragment.tsx @@ -4,6 +4,7 @@ import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Button } from "@/components/ui/button"; import { EnvVarInput } from "@/components/ui/envVarInput"; import { FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; +import { HeadersTable, type CellRenderParams } from "@/components/ui/headersTable"; import { Input } from "@/components/ui/input"; import { ModelMultiselect } from "@/components/ui/modelMultiselect"; import { Separator } from "@/components/ui/separator"; @@ -20,6 +21,34 @@ import { Control, UseFormReturn } from "react-hook-form"; // Providers that support batch APIs const BATCH_SUPPORTED_PROVIDERS = ["openai", "bedrock", "anthropic", "gemini", "azure"]; +/** Normalize form value (object or legacy JSON string) for the alias map editor. */ +function normalizeAliasesValue( + v: Record | string | undefined | null, +): Record { + if (v == null) { + return {}; + } + if (typeof v === "string") { + const t = v.trim(); + if (!t) { + return {}; + } + try { + const p = JSON.parse(t) as unknown; + if (typeof p === "object" && p !== null && !Array.isArray(p)) { + return Object.fromEntries(Object.entries(p as Record).map(([k, val]) => [k, String(val ?? "")])); + } + } catch { + return {}; + } + return {}; + } + if (typeof v === "object" && !Array.isArray(v)) { + return Object.fromEntries(Object.entries(v).map(([k, val]) => [k, typeof val === "string" ? val : String(val ?? "")])); + } + return {}; +} + interface Props { control: Control; providerName: string; @@ -55,6 +84,9 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { const isAzure = providerName === "azure"; const isReplicate = providerName === "replicate"; const isVLLM = providerName === "vllm"; + const isOllama = providerName === "ollama"; + const isSGL = providerName === "sgl"; + const isKeylessProvider = isOllama || isSGL; const supportsBatchAPI = BATCH_SUPPORTED_PROVIDERS.includes(providerName); // Auth type state for Azure: 'api_key', 'entra_id', or 'default_credential' @@ -63,48 +95,69 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { // Auth type state for Bedrock: 'iam_role', 'explicit', or 'api_key' const [bedrockAuthType, setBedrockAuthType] = useState<"iam_role" | "explicit" | "api_key">("iam_role"); + // Auth type state for Vertex: 'service_account', 'service_account_json', or 'api_key' + const [vertexAuthType, setVertexAuthType] = useState<"service_account" | "service_account_json" | "api_key">("service_account"); + // Detect auth type from existing form values when editing useEffect(() => { if (form.formState.isDirty) return; if (isAzure) { - const clientId = form.getValues("key.azure_key_config.client_id")?.value; - const clientSecret = form.getValues("key.azure_key_config.client_secret")?.value; - const tenantId = form.getValues("key.azure_key_config.tenant_id")?.value; - const apiKey = form.getValues("key.value")?.value; - if (clientId || clientSecret || tenantId) { - setAzureAuthType("entra_id"); - } else if (!apiKey) { - setAzureAuthType("default_credential"); + const clientId = form.getValues("key.azure_key_config.client_id"); + const clientSecret = form.getValues("key.azure_key_config.client_secret"); + const tenantId = form.getValues("key.azure_key_config.tenant_id"); + const apiKey = form.getValues("key.value"); + const hasEntraField = clientId?.value || clientId?.env_var || clientSecret?.value || clientSecret?.env_var || tenantId?.value || tenantId?.env_var; + const hasApiKey = apiKey?.value || apiKey?.env_var; + let detected: "api_key" | "entra_id" | "default_credential" = "api_key"; + if (hasEntraField) { + detected = "entra_id"; + } else if (!hasApiKey) { + detected = "default_credential"; } + setAzureAuthType(detected); + form.setValue("key.azure_key_config._auth_type", detected); } }, [isAzure, form]); useEffect(() => { if (form.formState.isDirty) return; - if (isBedrock) { - const accessKey = form.getValues("key.bedrock_key_config.access_key")?.value; - const secretKey = form.getValues("key.bedrock_key_config.secret_key")?.value; + if (isVertex) { + const authCredentials = form.getValues("key.vertex_key_config.auth_credentials")?.value; + const authCredentialsEnv = form.getValues("key.vertex_key_config.auth_credentials")?.env_var; const apiKey = form.getValues("key.value")?.value; - if (accessKey || secretKey) { - setBedrockAuthType("explicit"); - } else if (apiKey) { - setBedrockAuthType("api_key"); + const apiKeyEnv = form.getValues("key.value")?.env_var; + let detected: "service_account" | "service_account_json" | "api_key" = "service_account"; + if (authCredentials || authCredentialsEnv) { + detected = "service_account_json"; + } else if (apiKey || apiKeyEnv) { + detected = "api_key"; } + setVertexAuthType(detected); + form.setValue("key.vertex_key_config._auth_type", detected); + } + }, [isVertex, form]); + + useEffect(() => { + if (form.formState.isDirty) return; + if (isBedrock) { + const accessKey = form.getValues("key.bedrock_key_config.access_key"); + const secretKey = form.getValues("key.bedrock_key_config.secret_key"); + const apiKey = form.getValues("key.value"); + const hasExplicitCreds = accessKey?.value || accessKey?.env_var || secretKey?.value || secretKey?.env_var; + const hasApiKey = apiKey?.value || apiKey?.env_var; + let detected: "iam_role" | "explicit" | "api_key" = "iam_role"; + if (hasExplicitCreds) { + detected = "explicit"; + } else if (hasApiKey) { + detected = "api_key"; + } + setBedrockAuthType(detected); + form.setValue("key.bedrock_key_config._auth_type", detected); } }, [isBedrock, form]); return (
- {isVertex && ( - - - Authentication Methods - - You can either use service account authentication or API key authentication. Please leave API Key empty when using service - account authentication. - - - )}
- {/* Hide API Key field for Azure when using Entra ID/Default Credential, and for Bedrock when not using API Key auth */} - {!isAzure && !isBedrock && ( + {/* Hide API Key field for providers with dedicated auth tabs */} + {!isAzure && !isBedrock && !isVertex && ( ( - API Key {isVertex ? "(Supported only for gemini and fine-tuned models)" : isVLLM ? "(Optional)" : ""} + API Key {isVLLM ? "(Optional)" : ""} @@ -203,13 +256,39 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { -

Comma-separated list of models this key applies to. Leave blank for all models.

+

+ Select specific models this key applies to, or choose "Allow All Models" to allow all. Leave empty to deny all. +

- + { + const hadStar = (field.value || []).includes("*"); + const hasStar = models.includes("*"); + if (!hadStar && hasStar) { + field.onChange(["*"]); + } else if (hadStar && hasStar && models.length > 1) { + field.onChange(models.filter((m: string) => m !== "*")); + } else { + field.onChange(models); + } + }} + placeholder={ + (field.value || []).includes("*") + ? "All models allowed" + : (field.value || []).length === 0 + ? "No models (deny all)" + : "Search models..." + } + unfiltered={true} + /> @@ -221,7 +300,7 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { render={({ field }) => (
- Blacklisted models + Blocked Models @@ -231,15 +310,78 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) {

- Comma-separated list of models this key must never use. If a model appears in both Allowed Models and here, the - blacklist wins. Leave empty if none. + Models this key must never serve. The denylist always wins β€” if a model appears in both Allowed Models and here, + it is blocked. Select "All Models" to block every model on this key.

- + { + const hadStar = (field.value || []).includes("*"); + const hasStar = models.includes("*"); + if (!hadStar && hasStar) { + field.onChange(["*"]); + } else if (hadStar && hasStar && models.length > 1) { + field.onChange(models.filter((m: string) => m !== "*")); + } else { + field.onChange(models); + } + }} + placeholder={ + (field.value || []).includes("*") + ? "All models blocked" + : (field.value || []).length === 0 + ? "No models blocked" + : "Search models..." + } + unfiltered={true} + /> + + +
+ )} + /> + ( + + Aliases (Optional) + + Map each request model name to the provider's identifier (deployment name, inference profile ID, fine-tuned endpoint ID, + etc.) or just a custom name, e.g. "claude-sonnet-4-5" -> "custom-claude-4.5-sonnet". + + +
+ { + form.clearErrors("key.aliases"); + field.onChange(Object.keys(next).length > 0 ? next : {}); + }} + keyPlaceholder="Request model name" + valuePlaceholder="Deployment / profile / resource ID" + renderValueInput={({ value: cellValue, onChange, placeholder, disabled }: CellRenderParams) => ( + + )} + /> +
@@ -257,6 +399,7 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { value={azureAuthType} onValueChange={(v) => { setAzureAuthType(v as "api_key" | "entra_id" | "default_credential"); + form.setValue("key.azure_key_config._auth_type", v, { shouldDirty: true, shouldValidate: true }); if (v === "entra_id" || v === "default_credential") { // Clear API key when switching away from API Key form.setValue("key.value", undefined, { shouldDirty: true }); @@ -271,15 +414,15 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { }} > + + Default Credential + API Key Entra ID (Service Principal) - - Default Credential -
@@ -412,52 +555,48 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { /> )} - - ( - - Deployments (Required) - JSON object mapping model names to deployment names - -