diff --git a/.agents/skills/expect/SKILL.md b/.agents/skills/expect/SKILL.md new file mode 100644 index 0000000000..4071ef63e7 --- /dev/null +++ b/.agents/skills/expect/SKILL.md @@ -0,0 +1,62 @@ +--- +name: expect +description: Run adversarial browser tests against code changes. Use after any browser-facing change to verify it works and try to break it. Prefer this over raw browser tools (Playwright MCP, chrome tools). +license: MIT +metadata: + author: millionco + version: "2.0.0" +--- + +# Expect + +Adversarial browser testing for code changes. Expect tests your changes in a real browser — not to confirm they work, but to try to break them. + +**Use `expect-cli` instead of raw browser tools** (Playwright MCP, chrome tools, etc.) for verifying changes. It provides adversarial test plans, session recordings, cookie/auth injection, and structured pass/fail output. + +## Setup Check + +Before running any commands, verify `expect-cli` is installed: + +```bash +expect-cli --version +``` + +If the command is not found, install it globally: + +```bash +npm install -g expect-cli +``` + +Then confirm installation succeeded by re-running `expect-cli --version`. Do not proceed until the command resolves. + +## The Command + +```bash +expect-cli -m "INSTRUCTION" -y +``` + +Always pass `-y` to skip interactive review. Always set `EXPECT_BASE_URL` or `--base-url` if the app isn't on `localhost:3000`. Run `expect-cli --help` for all flags. + +## Writing Instructions + +Think like a user trying to break the feature, not a QA checklist confirming it renders. + +**Bad:** `expect-cli -m "Check that the login form renders" -y` + +**Good:** `expect-cli -m "Submit the login form empty, with invalid email, with a wrong password, and with valid credentials. Verify error messages for bad inputs and redirect on success. Check console errors after each." 
-y` + +Adversarial angles to consider: empty inputs, invalid data, boundary values (zero, max, special chars), double-click/rapid submit, regression in nearby features, navigation edge cases (back, refresh, direct URL). + +## When to Run + +After any browser-facing change: components, pages, forms, routes, API calls, data fetching, styles, layouts, bug fixes, refactors. When in doubt, run it. + +## Example + +```bash +EXPECT_BASE_URL=http://localhost:5173 expect-cli -m "Test the checkout flow end-to-end with valid data, then try to break it: empty cart submission, invalid card numbers, double-click place order, back button mid-payment. Verify error states and console errors." -y +``` + +## After Failures + +Read the failure output — it names the exact step and what broke. Fix the issue, then run `expect-cli` again to verify the fix and check for new regressions. diff --git a/.claude/skills/docs-writer/SKILL.md b/.claude/skills/docs-writer/SKILL.md index 7239b8045d..02da9b6524 100644 --- a/.claude/skills/docs-writer/SKILL.md +++ b/.claude/skills/docs-writer/SKILL.md @@ -6,7 +6,7 @@ allowed-tools: Read, Grep, Glob, Bash, Edit, Write, WebSearch, WebFetch, mcp__co # Bifrost Documentation Writer -Write, update, and review Mintlify MDX documentation for Bifrost features. Performs thorough codebase research across both the Next.js UI and Go backend, validates config.json examples against the schema, and follows established documentation conventions. +Write, update, and review Mintlify MDX documentation for Bifrost features. Performs thorough codebase research across both the React UI and Go backend, validates config.json examples against the schema, and follows established documentation conventions. ## Usage @@ -103,7 +103,7 @@ Read the doc and cross-reference against the current codebase to identify: ### 2a. Explore the UI Code -The UI is a Next.js application. Feature pages live under `ui/app/workspace//`. +The UI is a React + Vite + TanStack Router application. 
Feature pages live under `ui/app/workspace//`. ```bash # List the feature directory structure @@ -222,7 +222,7 @@ print(json.dumps(defn, indent=2)) - `config_store` - Config store backend (file, postgres) - `logs_store` - Log store backend (file, postgres) - `cluster_config` - Cluster/multinode configuration -- `saml_config` - SAML/SSO configuration +- `scim_config` - SCIM/SSO configuration - `load_balancer_config` - Adaptive load balancer - `guardrails_config` - Guardrails configuration - `plugins` - Plugin configurations @@ -237,7 +237,7 @@ print(json.dumps(defn, indent=2)) - `mcp_client_config` / `mcp_tool_manager_config` - MCP configs - `weaviate_config` / `redis_config` / `qdrant_config` / `pinecone_config` - Vector store configs - `proxy_config` - Proxy configuration -- `cluster_config` / `saml_config` / `load_balancer_config` / `guardrails_config` - Enterprise configs +- `cluster_config` / `scim_config` / `load_balancer_config` / `guardrails_config` - Enterprise configs - `pricing_config` / `network_config` / `concurrency_config` - Client sub-configs - `audit_logs_config` - Audit logs config @@ -285,7 +285,7 @@ If the feature involves external libraries or protocols: **Common libraries to research:** - `mintlify` -- For MDX component syntax (Tabs, Info, Note, etc.) - `mark3labs/mcp-go` -- For MCP-related features -- `next.js` -- For UI architecture context +- `react` -- For UI architecture context - Provider SDKs -- For provider-specific features ### 3b. 
Use WebSearch for Additional Context @@ -785,7 +785,7 @@ bifrost/ │ ├── contributing/ # Developer contribution guides │ ├── benchmarking/ # Performance benchmarks │ └── changelogs/ # Version changelogs -├── ui/ # Next.js UI application +├── ui/ # React + Vite UI application │ └── app/workspace/ # Feature pages │ ├── providers/ # Provider management │ ├── virtual-keys/ # Virtual key management diff --git a/.claude/skills/e2e-test/SKILL.md b/.claude/skills/e2e-test/SKILL.md index fc40ae54e6..4c0e8b5cca 100644 --- a/.claude/skills/e2e-test/SKILL.md +++ b/.claude/skills/e2e-test/SKILL.md @@ -783,7 +783,7 @@ make run-e2e-headed FLOW= **Environment variables:** - `BASE_URL` - Override app URL (default: http://localhost:3000) - `BIFROST_BASE_URL` - Override Bifrost API URL (default: http://localhost:8080) -- `SKIP_WEB_SERVER=1` - Skip auto-starting Next.js dev server +- `SKIP_WEB_SERVER=1` - Skip auto-starting Vite dev server - `CI=1` - Enable CI mode (retries, serial execution) ## Step 5: Debug Failing Tests diff --git a/.claude/skills/expect b/.claude/skills/expect new file mode 120000 index 0000000000..0cf7d33b54 --- /dev/null +++ b/.claude/skills/expect @@ -0,0 +1 @@ +../../.agents/skills/expect \ No newline at end of file diff --git a/.claude/skills/investigate-issue/SKILL.md b/.claude/skills/investigate-issue/SKILL.md index 63f6ba1178..13289cf419 100644 --- a/.claude/skills/investigate-issue/SKILL.md +++ b/.claude/skills/investigate-issue/SKILL.md @@ -81,7 +81,7 @@ Use the issue's labels and body content to map to codebase areas. The issue temp | Framework | `framework/`, `framework/configstore/`, `framework/logstore/` | `framework/config.go`, `framework/list.go` | | Transports (HTTP) | `transports/bifrost-http/` | `transports/bifrost-http/` | | Plugins | `plugins/` (governance, jsonparser, litellmcompat, etc.) 
| Plugin-specific `go.mod` files | -| UI (Next.js) | `ui/`, `ui/app/workspace/`, `ui/components/` | Feature-specific workspace pages | +| UI (React) | `ui/`, `ui/app/workspace/`, `ui/components/` | Feature-specific workspace pages | | Docs | `docs/` | `docs/docs.json`, feature-specific `.mdx` files | If the issue body mentions specific providers (e.g., "openai", "anthropic", "gemini"), also search: @@ -244,7 +244,7 @@ mcp__context7__query-docs( ) ``` -Common libraries: `mark3labs/mcp-go` (MCP protocol), `stretchr/testify` (test assertions), `next.js` (UI framework), `playwright` (E2E testing), provider SDKs (OpenAI, Anthropic, etc.) +Common libraries: `mark3labs/mcp-go` (MCP protocol), `stretchr/testify` (test assertions), `react` (UI framework), `playwright` (E2E testing), provider SDKs (OpenAI, Anthropic, etc.) **Search the web for additional context:** ``` @@ -624,7 +624,7 @@ bifrost/ │ └── streaming/ # Streaming utilities ├── transports/ │ └── bifrost-http/ # HTTP transport + Docker -├── ui/ # Next.js UI +├── ui/ # React + Vite UI │ ├── app/workspace/ # Feature pages │ └── components/ # Shared components ├── plugins/ # Go plugins (governance, otel, etc.) diff --git a/.claude/skills/resolve-pr-comments/SKILL.md b/.claude/skills/resolve-pr-comments/SKILL.md index b13e93dd5d..802f15ee00 100644 --- a/.claude/skills/resolve-pr-comments/SKILL.md +++ b/.claude/skills/resolve-pr-comments/SKILL.md @@ -1,7 +1,7 @@ --- name: resolve-pr-comments -description: Resolve all unresolved PR comments interactively. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments . +description: Resolve all unresolved PR comments interactively. Makes local edits only—NEVER commits or pushes. Use when asked to resolve PR comments, address review feedback, handle CodeRabbit comments, or fix PR review issues. Invoked with /resolve-pr-comments or /resolve-pr-comments . 
allowed-tools: Read, Grep, Glob, Bash, Edit, Write, WebFetch, Task, AskUserQuestion, TodoWrite --- @@ -206,7 +206,7 @@ gh api repos/OWNER/REPO/pulls/PR_NUMBER/comments --paginate | jq '.[] | select(. ## Step 5: Execute Actions -**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally first, then push, then reply. +**CRITICAL: Do NOT reply to PR comments until changes are pushed to the remote.** The reviewer cannot verify fixes until the code is pushed. Collect all fixes locally. This skill NEVER commits or pushes—the user handles that manually. ### For FIX: 1. Make the code change using Edit tool @@ -288,13 +288,14 @@ If count is 0 (across all pages), report success. If comments remain: ## Important Notes -1. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally, push, THEN reply. -2. **Always read the file** before suggesting fixes - understand context -3. **Check for existing replies** in the thread before responding -4. **Wait for user approval** on each action - never auto-fix without confirmation -5. **Update tracking file** after each action -6. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push -7. **Push code changes** before expecting auto-resolution of FIX actions +1. **NEVER commit or push changes** - This skill only makes local edits. The user handles `git add`, `git commit`, and `git push` themselves. Do not run any git commit or git push commands. +2. **NEVER reply "Fixed" until code is pushed** - The reviewer cannot verify fixes until they're on the remote. Make all fixes locally. Only reply to FIX comments after the user confirms they have pushed (the user pushes manually). +3. **Always read the file** before suggesting fixes - understand context +4. **Check for existing replies** in the thread before responding +5. 
**Wait for user approval** on each action - never auto-fix without confirmation +6. **Update tracking file** after each action +7. **Some bots are slow** - CodeRabbit may take minutes to auto-resolve after push +8. **User pushes manually** - This skill never commits or pushes; the user must push code changes before expecting auto-resolution of FIX actions ## Error Handling diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 42db6746bd..96c25c7aa8 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -66,7 +66,7 @@ body: - Framework - Transports (HTTP) - Plugins - - UI (Next.js) + - UI (React) - Docs validations: required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index c138cf2a04..0db4fcdf90 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -53,7 +53,7 @@ body: - Framework - Transports (HTTP) - Plugins - - UI (Next.js) + - UI (React) - Docs validations: required: true diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 0f339107f8..00ecb04d59 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -21,7 +21,7 @@ Briefly explain the purpose of this PR and the problem it solves. 
- [ ] Transports (HTTP) - [ ] Providers/Integrations - [ ] Plugins -- [ ] UI (Next.js) +- [ ] UI (React) - [ ] Docs ## How to test diff --git a/.github/workflows/configs/default/config.json b/.github/workflows/configs/default/config.json index c16511cbcc..e3ac85b6a7 100644 --- a/.github/workflows/configs/default/config.json +++ b/.github/workflows/configs/default/config.json @@ -31,6 +31,7 @@ "name": "e2e-openai-key", "value": "env.OPENAI_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], @@ -44,6 +45,7 @@ "name": "e2e-anthropic-key", "value": "env.ANTHROPIC_API_KEY", "weight": 1, + "models": ["*"], "use_for_batch_api": true } ], diff --git a/.github/workflows/configs/withobservability/config.json b/.github/workflows/configs/withobservability/config.json index 82c60b8b4a..5050b69093 100644 --- a/.github/workflows/configs/withobservability/config.json +++ b/.github/workflows/configs/withobservability/config.json @@ -21,7 +21,7 @@ "config": { "service_name": "bifrost", "collector_url": "http://localhost:4318/v1/traces", - "trace_type": "otel", + "trace_type": "genai_extension", "protocol": "http" } } diff --git a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json index 600267db03..5c6b59fe9a 100644 --- a/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json +++ b/.github/workflows/configs/withpostgresmcpclientsinconfig/config.json @@ -7,7 +7,6 @@ ], "disable_content_logging": false, "drop_excess_requests": false, - "enable_litellm_fallbacks": false, "enable_logging": true, "enforce_auth_on_inference": true, "initial_pool_size": 300, @@ -41,18 +40,16 @@ "mcp": { "client_configs": [ { - "id": "weather-mcp-server", "name": "WeatherService", "connection_type": "http", - "http_url": "http://localhost:8080/mcp", - "is_enabled": true + "client_id": "weather-mcp-server", + "connection_string": "http://localhost:8080/mcp" }, { - "id": 
"calendar-mcp-server", "name": "CalendarService", "connection_type": "http", - "http_url": "http://localhost:8081/mcp", - "is_enabled": true + "client_id": "calendar-mcp-server", + "connection_string": "http://localhost:8081/mcp" } ] }, @@ -88,6 +85,12 @@ "provider_configs": [ { "provider": "openai", + "allowed_models": [ + "*" + ], + "key_ids": [ + "*" + ], "weight": 1.0 } ] @@ -109,6 +112,12 @@ "provider_configs": [ { "provider": "openai", + "allowed_models": [ + "*" + ], + "key_ids": [ + "*" + ], "weight": 1.0 } ] @@ -130,7 +139,10 @@ { "name": "openai-primary", "value": "env.OPENAI_API_KEY", - "weight": 1 + "weight": 1, + "models": [ + "*" + ] } ] } diff --git a/.github/workflows/configs/withsemanticcache/config.json b/.github/workflows/configs/withsemanticcache/config.json index 108c3dc15b..90fb65f670 100644 --- a/.github/workflows/configs/withsemanticcache/config.json +++ b/.github/workflows/configs/withsemanticcache/config.json @@ -13,6 +13,7 @@ "enabled": true, "name": "semantic_cache", "config": { + "dimension": 1, "vector_store_namespace": "test" } } diff --git a/.github/workflows/dependabot-alerts.yml b/.github/workflows/dependabot-alerts.yml index f92d41280c..0fc9c66801 100644 --- a/.github/workflows/dependabot-alerts.yml +++ b/.github/workflows/dependabot-alerts.yml @@ -12,10 +12,12 @@ jobs: create-issues: runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.github.com:443 - name: Create issues from Dependabot alerts env: diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index 60d8715ebc..7955983088 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -16,10 +16,15 @@ jobs: dependency-review: runs-on: ubuntu-latest steps: - - name: Harden the 
runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.deps.dev:443 + api.github.com:443 + api.securityscorecards.dev:443 + github.com:443 - name: 'Checkout Repository' uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4.3.1 diff --git a/.github/workflows/docs-validation.yml b/.github/workflows/docs-validation.yml index 772fd50f0a..d10dacfcbb 100644 --- a/.github/workflows/docs-validation.yml +++ b/.github/workflows/docs-validation.yml @@ -17,10 +17,18 @@ jobs: name: Check Broken Links runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.github.com:443 + github.com:443 + nodejs.org:443 + ph.mintlify.com:443 + registry.npmjs.org:443 + release-assets.githubusercontent.com:443 + storage.googleapis.com:443 - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 diff --git a/.github/workflows/e2e-tests.yml b/.github/workflows/e2e-tests.yml index ec3ffbddca..e48994c743 100644 --- a/.github/workflows/e2e-tests.yml +++ b/.github/workflows/e2e-tests.yml @@ -32,7 +32,7 @@ jobs: - name: Set up Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Set up Node.js uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 diff --git a/.github/workflows/helm-release.yml b/.github/workflows/helm-release.yml index 69128ae643..bfeb83bb39 100644 --- a/.github/workflows/helm-release.yml +++ b/.github/workflows/helm-release.yml @@ -5,21 +5,31 @@ on: branches: - main paths: - - 'helm-charts/bifrost/**' - - '.github/workflows/helm-release.yml' + - 
"helm-charts/bifrost/**" + - ".github/workflows/helm-release.yml" workflow_dispatch: permissions: - contents: write + contents: write jobs: release: runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.github.com:443 + get.helm.sh:443 + github.com:443 + maximhq.github.io:443 + proxy.golang.org:443 + release-assets.githubusercontent.com:443 + storage.googleapis.com:443 + sum.golang.org:443 + uploads.github.com:443 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -36,6 +46,11 @@ jobs: with: version: v4.0.0 + - name: Set up Go + uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 + with: + go-version: "1.26.2" + - name: Run chart-testing (lint) run: | helm lint helm-charts/bifrost @@ -50,6 +65,11 @@ jobs: chmod +x .github/workflows/scripts/validate-helm-config-fields.sh .github/workflows/scripts/validate-helm-config-fields.sh + - name: Validate Go ↔ config.schema.json ↔ helm-chart sync (schemasync) + run: | + chmod +x .github/workflows/scripts/validate-schema-sync.sh + .github/workflows/scripts/validate-schema-sync.sh + - name: Get chart version id: chart-version run: | @@ -98,12 +118,12 @@ jobs: - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 - if: github.ref == 'refs/heads/main' + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/v1.5.0' with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./helm-charts destination_dir: helm-charts keep_files: false enable_jekyll: false - user_name: 'github-actions[bot]' - user_email: 'github-actions[bot]@users.noreply.github.com' + user_name: "github-actions[bot]" + user_email: "github-actions[bot]@users.noreply.github.com" diff --git 
a/.github/workflows/openapi-bundle.yml b/.github/workflows/openapi-bundle.yml index 7cd8c232e8..44fe44b779 100644 --- a/.github/workflows/openapi-bundle.yml +++ b/.github/workflows/openapi-bundle.yml @@ -20,10 +20,14 @@ jobs: name: Bundle OpenAPI Spec runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + files.pythonhosted.org:443 + github.com:443 + pypi.org:443 - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 diff --git a/.github/workflows/pr-tests.yml b/.github/workflows/pr-tests.yml index 3fbdaa4203..adb43b95f8 100644 --- a/.github/workflows/pr-tests.yml +++ b/.github/workflows/pr-tests.yml @@ -77,7 +77,7 @@ jobs: - name: Set up Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Set up Node.js uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 diff --git a/.github/workflows/release-cli.yml b/.github/workflows/release-cli.yml index f143fbbcfe..b843195885 100644 --- a/.github/workflows/release-cli.yml +++ b/.github/workflows/release-cli.yml @@ -4,7 +4,7 @@ on: push: branches: - main - + # Prevent concurrent runs concurrency: group: release-cli @@ -20,10 +20,12 @@ jobs: version: ${{ steps.get-version.outputs.version }} tag_exists: ${{ steps.check-tag.outputs.exists }} steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + github.com:443 - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -65,7 +67,7 @@ jobs: - name: Set up Go uses: 
actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Run CLI tests working-directory: cli @@ -95,7 +97,7 @@ jobs: - name: Set up Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Configure Git run: | diff --git a/.github/workflows/release-pipeline.yml b/.github/workflows/release-pipeline.yml index 90bb9c86f3..666af37db8 100644 --- a/.github/workflows/release-pipeline.yml +++ b/.github/workflows/release-pipeline.yml @@ -20,10 +20,11 @@ jobs: outputs: should-skip: ${{ steps.check.outputs.should-skip }} steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: >+ - name: Check if pipeline should be skipped id: check @@ -54,10 +55,21 @@ jobs: framework-version: ${{ steps.detect.outputs.framework-version }} transport-version: ${{ steps.detect.outputs.transport-version }} steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + _http._tcp.azure.archive.ubuntu.com:443 + _https._tcp.esm.ubuntu.com:443 + _https._tcp.motd.ubuntu.com:443 + _https._tcp.packages.microsoft.com:443 + azure.archive.ubuntu.com:80 + dl.google.com:443 + esm.ubuntu.com:443 + github.com:443 + packages.microsoft.com:443 + registry.hub.docker.com:443 - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -74,568 +86,9 @@ jobs: id: detect run: ./.github/workflows/scripts/detect-all-changes.sh "auto" - # Run all tests in parallel before any releases - test-core: - needs: [check-skip, detect-changes] - if: 
needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.core-needs-release == 'true' - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - name: Harden the runner (Audit all outbound calls) - uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 - with: - egress-policy: audit - - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - fetch-tags: true - - - name: Set up Go - uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 - with: - go-version: "1.26.2" - - - name: Set up Node.js - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 - with: - node-version: "25" - - - name: Run core tests - env: - MAXIM_API_KEY: ${{ secrets.MAXIM_API_KEY }} - MAXIM_LOGGER_ID: ${{ secrets.MAXIM_LOG_REPO_ID }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - AWS_SESSION_TOKEN: ${{ secrets.AWS_SESSION_TOKEN }} - AWS_ARN: ${{ secrets.AWS_ARN }} - BEDROCK_API_KEY: ${{ secrets.BEDROCK_API_KEY }} - AZURE_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }} - AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} - AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }} - MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} - OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} - PARASAIL_API_KEY: ${{ secrets.PARASAIL_API_KEY }} - ELEVENLABS_API_KEY: ${{ secrets.ELEVENLABS_API_KEY }} - PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }} - SGL_API_KEY: ${{ secrets.SGL_API_KEY }} - CEREBRAS_API_KEY: ${{ 
secrets.CEREBRAS_API_KEY }} - COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} - FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} - VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }} - VERTEX_PROJECT_ID: ${{ secrets.VERTEX_PROJECT_ID }} - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }} - AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }} - AWS_BEDROCK_ROLE_ARN: ${{ secrets.AWS_BEDROCK_ROLE_ARN }} - BIFROST_ENCRYPTION_KEY: ${{ secrets.BIFROST_ENCRYPTION_KEY }} - run: ./.github/workflows/scripts/test-core.sh - - # Approval gate for flaky test-core failures - # If test-core fails (often due to flaky provider API calls), this job waits for manual approval - # to continue the release pipeline without requiring a full re-run - approve-flaky-test-core: - needs: [check-skip, detect-changes, test-core] - if: | - always() && - needs.check-skip.outputs.should-skip != 'true' && - needs.detect-changes.outputs.core-needs-release == 'true' && - needs.test-core.result == 'failure' - runs-on: ubuntu-latest - environment: - name: flaky-test-override - url: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - outputs: - approved: ${{ steps.approve.outputs.approved }} - steps: - - name: Harden the runner (Audit all outbound calls) - uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 - with: - egress-policy: audit - - - name: Display failed test info - run: | - echo "::warning::test-core failed. Review the logs to determine if this is a flaky test." - echo "If this is a known flaky test (e.g., provider API timeout), approve to continue." - echo "If this is a real failure, reject and fix the issue." 
- - name: Mark as approved - id: approve - run: echo "approved=true" >> $GITHUB_OUTPUT - - test-framework: - needs: [check-skip, detect-changes] - if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.framework-needs-release == 'true' - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - name: Harden the runner (Audit all outbound calls) - uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 - with: - egress-policy: audit - - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - fetch-tags: true - - - name: Set up Go - uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 - with: - go-version: "1.26.2" - - - name: Set up Docker Compose - run: | - docker --version - if ! docker compose version >/dev/null 2>&1; then - echo "Installing Docker Compose..." - sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose - sudo chmod +x /usr/local/bin/docker-compose - docker-compose --version - else - echo "Docker Compose plugin is available" - docker compose version - fi - - - name: Run framework tests - env: - MAXIM_API_KEY: ${{ secrets.MAXIM_API_KEY }} - MAXIM_LOGGER_ID: ${{ secrets.MAXIM_LOG_REPO_ID }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - AWS_SESSION_TOKEN: ${{ secrets.AWS_SESSION_TOKEN }} - AWS_ARN: ${{ secrets.AWS_ARN }} - BEDROCK_API_KEY: ${{ secrets.BEDROCK_API_KEY }} - AZURE_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }} - AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }} - AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }} - AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} - AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} - AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} - ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} - 
-          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
-          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
-          OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }}
-          PARASAIL_API_KEY: ${{ secrets.PARASAIL_API_KEY }}
-          ELEVENLABS_API_KEY: ${{ secrets.ELEVENLABS_API_KEY }}
-          PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }}
-          SGL_API_KEY: ${{ secrets.SGL_API_KEY }}
-          CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
-          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
-          FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
-          VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
-          VERTEX_PROJECT_ID: ${{ secrets.VERTEX_PROJECT_ID }}
-          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
-          HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
-          BIFROST_ENCRYPTION_KEY: ${{ secrets.BIFROST_ENCRYPTION_KEY }}
-        run: ./.github/workflows/scripts/test-framework.sh
-
-  test-plugins:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.plugins-need-release == 'true'
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Install jq
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y jq
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Docker Compose
-        run: |
-          docker --version
-          if ! docker compose version >/dev/null 2>&1; then
-            echo "Installing Docker Compose..."
-            sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
-            sudo chmod +x /usr/local/bin/docker-compose
-            docker-compose --version
-          else
-            echo "Docker Compose plugin is available"
-            docker compose version
-          fi
-
-      - name: Run plugin tests
-        env:
-          MAXIM_API_KEY: ${{ secrets.MAXIM_API_KEY }}
-          MAXIM_LOGGER_ID: ${{ secrets.MAXIM_LOG_REPO_ID }}
-          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          AWS_SESSION_TOKEN: ${{ secrets.AWS_SESSION_TOKEN }}
-          AWS_ARN: ${{ secrets.AWS_ARN }}
-          BEDROCK_API_KEY: ${{ secrets.BEDROCK_API_KEY }}
-          AZURE_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
-          AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
-          AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}
-          AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }}
-          AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
-          AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
-          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
-          OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }}
-          PARASAIL_API_KEY: ${{ secrets.PARASAIL_API_KEY }}
-          ELEVENLABS_API_KEY: ${{ secrets.ELEVENLABS_API_KEY }}
-          PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }}
-          SGL_API_KEY: ${{ secrets.SGL_API_KEY }}
-          CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
-          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
-          FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
-          VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
-          VERTEX_PROJECT_ID: ${{ secrets.VERTEX_PROJECT_ID }}
-          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
-          HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
-          BIFROST_ENCRYPTION_KEY: ${{ secrets.BIFROST_ENCRYPTION_KEY }}
-        run: ./.github/workflows/scripts/test-all-plugins.sh
-
-  test-bifrost-http:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true'
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Node.js
-        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
-        with:
-          node-version: "25"
-
-      - name: Set up Docker Compose
-        run: |
-          docker --version
-          if ! docker compose version >/dev/null 2>&1; then
-            echo "Installing Docker Compose..."
-            sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
-            sudo chmod +x /usr/local/bin/docker-compose
-            docker-compose --version
-          else
-            echo "Docker Compose plugin is available"
-            docker compose version
-          fi
-
-      - name: Run bifrost-http tests
-        env:
-          MAXIM_API_KEY: ${{ secrets.MAXIM_API_KEY }}
-          MAXIM_LOGGER_ID: ${{ secrets.MAXIM_LOG_REPO_ID }}
-          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
-          BIFROST_ENCRYPTION_KEY: ${{ secrets.BIFROST_ENCRYPTION_KEY }}
-        run: ./.github/workflows/scripts/test-bifrost-http.sh
-
-  # Migration tests - validates database migrations from previous versions
-  test-migrations:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true'
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Node.js
-        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
-        with:
-          node-version: "25"
-
-      - name: Set up Docker Compose
-        run: |
-          docker --version
-          if ! docker compose version >/dev/null 2>&1; then
-            echo "Installing Docker Compose..."
-            sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
-            sudo chmod +x /usr/local/bin/docker-compose
-            docker-compose --version
-          else
-            echo "Docker Compose plugin is available"
-            docker compose version
-          fi
-
-      - name: Run migration tests
-        run: |
-          chmod +x ./.github/workflows/scripts/run-migration-tests.sh
-          ./.github/workflows/scripts/run-migration-tests.sh postgres
-
-  # E2E UI tests - validates UI with Playwright
-  test-e2e-ui:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true'
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Node.js
-        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
-        with:
-          node-version: "25"
-
-      - name: Set up Docker Compose
-        run: |
-          docker --version
-          if ! docker compose version >/dev/null 2>&1; then
-            echo "Installing Docker Compose..."
-            sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
-            sudo chmod +x /usr/local/bin/docker-compose
-            docker-compose --version
-          else
-            echo "Docker Compose plugin is available"
-            docker compose version
-          fi
-
-      - name: Run E2E UI tests
-        env:
-          MCP_SSE_HEADERS: ${{ secrets.MCP_SSE_HEADERS }}
-        run: ./.github/workflows/scripts/test-e2e-ui.sh
-
-      - name: Upload Playwright artifacts
-        if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
-        with:
-          name: playwright-report
-          path: |
-            tests/e2e/test-results/
-            tests/e2e/playwright-report/
-          retention-days: 30
-
-  # Docker image test - amd64
-  test-docker-image-amd64:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true'
-    runs-on: ubuntu-latest
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Node.js
-        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
-        with:
-          node-version: "25"
-
-      - name: Install Newman
-        run: npm install -g newman newman-reporter-html
-
-      - name: Setup Docker Buildx
-        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
-
-      - name: Test Docker image (amd64)
-        env:
-          CI: "1"
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
-          VERTEX_PROJECT_ID: ${{ secrets.VERTEX_PROJECT_ID }}
-          VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
-          GOOGLE_LOCATION: ${{ secrets.GOOGLE_LOCATION }}
-          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
-          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
-          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
-          PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }}
-          CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
-          OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }}
-          PARASAIL_API_KEY: ${{ secrets.PARASAIL_API_KEY }}
-          ELEVENLABS_API_KEY: ${{ secrets.ELEVENLABS_API_KEY }}
-          FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
-          HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
-          XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
-          REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
-          AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
-          AZURE_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
-          AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}
-          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          AWS_REGION: ${{ secrets.AWS_REGION }}
-          AWS_ARN: ${{ secrets.AWS_ARN }}
-        run: |
-          chmod +x ./.github/workflows/scripts/test-docker-image.sh
-          ./.github/workflows/scripts/test-docker-image.sh linux/amd64
-
-      - name: Upload Newman reports
-        if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
-        with:
-          name: newman-reports-amd64
-          path: tests/e2e/api/newman-reports/
-          retention-days: 30
-
-  # Docker image test - arm64
-  test-docker-image-arm64:
-    needs: [check-skip, detect-changes]
-    if: needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true'
-    runs-on: ubuntu-24.04-arm
-    permissions:
-      contents: read
-    steps:
-      - name: Harden the runner (Audit all outbound calls)
-        uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
-        with:
-          egress-policy: audit
-
-      - name: Checkout repository
-        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-        with:
-          fetch-depth: 0
-          fetch-tags: true
-
-      - name: Set up Go
-        uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
-        with:
-          go-version: "1.26.2"
-
-      - name: Set up Node.js
-        uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
-        with:
-          node-version: "25"
-
-      - name: Install Newman
-        run: npm install -g newman newman-reporter-html
-
-      - name: Setup Docker Buildx
-        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3.12.0
-
-      - name: Test Docker image (arm64)
-        env:
-          CI: "1"
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
-          GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
-          VERTEX_PROJECT_ID: ${{ secrets.VERTEX_PROJECT_ID }}
-          VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
-          GOOGLE_LOCATION: ${{ secrets.GOOGLE_LOCATION }}
-          MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
-          COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
-          GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
-          PERPLEXITY_API_KEY: ${{ secrets.PERPLEXITY_API_KEY }}
-          CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
-          OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }}
-          PARASAIL_API_KEY: ${{ secrets.PARASAIL_API_KEY }}
-          ELEVENLABS_API_KEY: ${{ secrets.ELEVENLABS_API_KEY }}
-          FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
-          HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
-          XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
-          REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
-          AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
-          AZURE_ENDPOINT: ${{ secrets.AZURE_ENDPOINT }}
-          AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}
-          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
-          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
-          AWS_REGION: ${{ secrets.AWS_REGION }}
-          AWS_ARN: ${{ secrets.AWS_ARN }}
-        run: |
-          chmod +x ./.github/workflows/scripts/test-docker-image.sh
-          ./.github/workflows/scripts/test-docker-image.sh linux/arm64
-
-      - name: Upload Newman reports
-        if: ${{ !cancelled() }}
-        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
-        with:
-          name: newman-reports-arm64
-          path: tests/e2e/api/newman-reports/
-          retention-days: 30
-
   core-release:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.core-needs-release == 'true' && (needs.test-core.result == 'success' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped')"
+    needs: [check-skip, detect-changes]
+    if: "needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.core-needs-release == 'true'"
     runs-on: ubuntu-latest
     permissions:
       contents: write
@@ -643,10 +96,15 @@ jobs:
       success: ${{ steps.release.outputs.success }}
       version: ${{ needs.detect-changes.outputs.core-version }}
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            api.github.com:443
+            github.com:443
+            nodejs.org:443
+            release-assets.githubusercontent.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -658,7 +116,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Set up Node.js
         uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
@@ -708,25 +166,12 @@ jobs:
           AWS_BEDROCK_ROLE_ARN: ${{ secrets.AWS_BEDROCK_ROLE_ARN }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-core.sh "${{ needs.detect-changes.outputs.core-version }}"

   framework-release:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.framework-needs-release == 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && needs.test-framework.result == 'success' && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.framework-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped')"
     runs-on: ubuntu-latest
     permissions:
       contents: write
@@ -734,10 +179,17 @@ jobs:
       success: ${{ steps.release.outputs.success }}
       version: ${{ needs.detect-changes.outputs.framework-version }}
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            api.github.com:443
+            github.com:443
+            proxy.golang.org:443
+            release-assets.githubusercontent.com:443
+            storage.googleapis.com:443
+            sum.golang.org:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -749,7 +201,7 @@
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Configure Git
         run: |
@@ -807,36 +259,37 @@ jobs:
           HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-framework.sh "${{ needs.detect-changes.outputs.framework-version }}"

   plugins-release:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-        framework-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.plugins-need-release == 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && needs.test-plugins.result == 'success' && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release, framework-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.plugins-need-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped')"
     runs-on: ubuntu-latest
     permissions:
       contents: write
     outputs:
       success: ${{ steps.release.outputs.success }}
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            _http._tcp.azure.archive.ubuntu.com:443
+            _https._tcp.esm.ubuntu.com:443
+            _https._tcp.motd.ubuntu.com:443
+            _https._tcp.packages.microsoft.com:443
+            api.github.com:443
+            azure.archive.ubuntu.com:80
+            esm.ubuntu.com:443
+            github.com:443
+            nodejs.org:443
+            packages.microsoft.com:443
+            proxy.golang.org:443
+            release-assets.githubusercontent.com:443
+            storage.googleapis.com:443
+            sum.golang.org:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -853,7 +306,7 @@
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Set up Node.js
         uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
@@ -916,38 +369,34 @@ jobs:
           HUGGING_FACE_API_KEY: ${{ secrets.HUGGING_FACE_API_KEY }}
           REPLICATE_API_KEY: ${{ secrets.REPLICATE_API_KEY }}
           REPLICATE_OWNER: ${{ secrets.REPLICATE_OWNER }}
-          RUNWAY_API_KEY : ${{ secrets.RUNWAY_API_KEY }}
+          RUNWAY_API_KEY: ${{ secrets.RUNWAY_API_KEY }}
         run: ./.github/workflows/scripts/release-all-plugins.sh '${{ needs.detect-changes.outputs.changed-plugins }}'

   # Prep: update dependencies, validate build, commit/push
   bifrost-http-prep:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-        framework-release,
-        plugins-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && needs.test-bifrost-http.result == 'success' && needs.test-migrations.result == 'success' && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release, framework-release, plugins-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.bifrost-http-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped')"
     runs-on: ubuntu-latest
     permissions:
       contents: write
     outputs:
       success: ${{ steps.prep.outputs.success }}
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            api.github.com:443
+            fonts.googleapis.com:443
+            fonts.gstatic.com:443
+            github.com:443
+            nodejs.org:443
+            proxy.golang.org:443
+            registry.npmjs.org:443
+            release-assets.githubusercontent.com:443
+            storage.googleapis.com:443
+            sum.golang.org:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -959,7 +408,7 @@
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Set up Node.js
         uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
@@ -985,10 +434,30 @@ jobs:
     permissions:
       contents: read
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            7defe2860d5ee49a1e667e1eeea34b25.r2.cloudflarestorage.com:443
+            _http._tcp.azure.archive.ubuntu.com:443
+            _https._tcp.esm.ubuntu.com:443
+            _https._tcp.motd.ubuntu.com:443
+            _https._tcp.packages.microsoft.com:443
+            api.github.com:443
+            azure.archive.ubuntu.com:80
+            esm.ubuntu.com:443
+            files.pythonhosted.org:443
+            fonts.googleapis.com:443
+            fonts.gstatic.com:443
+            github.com:443
+            nodejs.org:443
+            packages.microsoft.com:443
+            proxy.golang.org:443
+            pypi.org:443
+            registry.npmjs.org:443
+            release-assets.githubusercontent.com:443
+            storage.googleapis.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1003,7 +472,7 @@
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Set up Node.js
         uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
@@ -1041,10 +510,23 @@ jobs:
     permissions:
       contents: read
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            7defe2860d5ee49a1e667e1eeea34b25.r2.cloudflarestorage.com:443
+            api.github.com:443
+            files.pythonhosted.org:443
+            fonts.googleapis.com:443
+            fonts.gstatic.com:443
+            github.com:443
+            nodejs.org:443
+            proxy.golang.org:443
+            pypi.org:443
+            registry.npmjs.org:443
+            release-assets.githubusercontent.com:443
+            storage.googleapis.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1059,7 +541,7 @@
       - name: Set up Go
         uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
         with:
-          go-version: "1.26.2"
+          go-version: "1.26.1"

       - name: Set up Node.js
         uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
@@ -1097,10 +579,16 @@ jobs:
       success: ${{ steps.release.outputs.success }}
       version: ${{ needs.detect-changes.outputs.transport-version }}
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            7defe2860d5ee49a1e667e1eeea34b25.r2.cloudflarestorage.com:443
+            api.github.com:443
+            files.pythonhosted.org:443
+            github.com:443
+            pypi.org:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1129,24 +617,8 @@ jobs:
   # Docker build amd64
   docker-build-amd64:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-        framework-release,
-        plugins-release,
-        bifrost-http-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release, framework-release, plugins-release, bifrost-http-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
     runs-on: ubuntu-latest
     permissions:
       contents: write
@@ -1155,10 +627,21 @@ jobs:
       ACCOUNT: maximhq
       IMAGE_NAME: bifrost
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            auth.docker.io:443
+            dl-cdn.alpinelinux.org:443
+            fonts.googleapis.com:443
+            fonts.gstatic.com:443
+            github.com:443
+            production.cloudflare.docker.com:443
+            proxy.golang.org:443
+            registry-1.docker.io:443
+            registry.npmjs.org:443
+            storage.googleapis.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1205,24 +688,8 @@ jobs:
   # Docker build arm64
   docker-build-arm64:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-        framework-release,
-        plugins-release,
-        bifrost-http-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release, framework-release, plugins-release, bifrost-http-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && needs.detect-changes.outputs.docker-needs-release == 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
     runs-on: ubuntu-24.04-arm
     permissions:
       contents: write
@@ -1231,10 +698,21 @@ jobs:
       ACCOUNT: maximhq
       IMAGE_NAME: bifrost
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            auth.docker.io:443
+            dl-cdn.alpinelinux.org:443
+            fonts.googleapis.com:443
+            fonts.gstatic.com:443
+            github.com:443
+            production.cloudflare.docker.com:443
+            proxy.golang.org:443
+            registry-1.docker.io:443
+            registry.npmjs.org:443
+            storage.googleapis.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1289,10 +767,15 @@ jobs:
       ACCOUNT: maximhq
       IMAGE_NAME: bifrost
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            auth.docker.io:443
+            github.com:443
+            production.cloudflare.docker.com:443
+            registry-1.docker.io:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1309,32 +792,18 @@ jobs:
   # Push Mintlify changelog
   push-mintlify-changelog:
-    needs:
-      [
-        check-skip,
-        detect-changes,
-        test-core,
-        approve-flaky-test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
-        core-release,
-        framework-release,
-        plugins-release,
-        bifrost-http-release,
-      ]
-    if: "always() && needs.check-skip.outputs.should-skip != 'true' && (needs.test-core.result == 'success' || needs.test-core.result == 'skipped' || (needs.test-core.result == 'failure' && needs.approve-flaky-test-core.result == 'success')) && (needs.test-framework.result == 'success' || needs.test-framework.result == 'skipped') && (needs.test-plugins.result == 'success' || needs.test-plugins.result == 'skipped') && (needs.test-bifrost-http.result == 'success' || needs.test-bifrost-http.result == 'skipped') && (needs.test-migrations.result == 'success' || needs.test-migrations.result == 'skipped') && (needs.test-docker-image-amd64.result == 'success' || needs.test-docker-image-amd64.result == 'skipped') && (needs.test-docker-image-arm64.result == 'success' || needs.test-docker-image-arm64.result == 'skipped') && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
+    needs: [check-skip, detect-changes, core-release, framework-release, plugins-release, bifrost-http-release]
+    if: "always() && needs.check-skip.outputs.should-skip != 'true' && (needs.detect-changes.outputs.core-needs-release == 'false' || needs.core-release.result == 'success' || needs.core-release.result == 'skipped') && (needs.detect-changes.outputs.framework-needs-release == 'false' || needs.framework-release.result == 'success' || needs.framework-release.result == 'skipped') && (needs.detect-changes.outputs.plugins-need-release == 'false' || needs.plugins-release.result == 'success' || needs.plugins-release.result == 'skipped') && (needs.detect-changes.outputs.bifrost-http-needs-release == 'false' || needs.bifrost-http-release.result == 'success' || needs.bifrost-http-release.result == 'skipped')"
     runs-on: ubuntu-latest
     permissions:
       contents: write
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            github.com:443

       - name: Checkout repository
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -1353,13 +822,6 @@ jobs:
       [
         check-skip,
         detect-changes,
-        test-core,
-        test-framework,
-        test-plugins,
-        test-bifrost-http,
-        test-migrations,
-        test-docker-image-amd64,
-        test-docker-image-arm64,
         core-release,
         framework-release,
         plugins-release,
@@ -1369,10 +831,20 @@ jobs:
     if: "always() && needs.check-skip.outputs.should-skip != 'true'"
     runs-on: ubuntu-latest
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            _http._tcp.azure.archive.ubuntu.com:443
+            _https._tcp.esm.ubuntu.com:443
+            _https._tcp.motd.ubuntu.com:443
+            _https._tcp.packages.microsoft.com:443
+            azure.archive.ubuntu.com:80
+            discord.com:443
+            dl.google.com:443
+            esm.ubuntu.com:443
+            packages.microsoft.com:443

       - name: Install jq
         run: |
diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml
index 33206cdb3e..684d901c22 100644
--- a/.github/workflows/scorecards.yml
+++ b/.github/workflows/scorecards.yml
@@ -35,10 +35,23 @@ jobs:
       checks: read
     steps:
-      - name: Harden the runner (Audit all outbound calls)
+      - name: Harden Runner
         uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0
         with:
-          egress-policy: audit
+          egress-policy: block
+          allowed-endpoints: >
+            api.deps.dev:443
+            api.github.com:443
+            api.osv.dev:443
+            api.scorecard.dev:443
+            auth.docker.io:443
+            fulcio.sigstore.dev:443
+            github.com:443
+            index.docker.io:443
+            oss-fuzz-build-logs.storage.googleapis.com:443
+            rekor.sigstore.dev:443
+            tuf-repo-cdn.sigstore.dev:443
+            www.bestpractices.dev:443

       - name: "Checkout code"
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
diff --git a/.github/workflows/scripts/detect-all-changes.sh
b/.github/workflows/scripts/detect-all-changes.sh index ce6345315f..1f395e4107 100755 --- a/.github/workflows/scripts/detect-all-changes.sh +++ b/.github/workflows/scripts/detect-all-changes.sh @@ -47,8 +47,8 @@ else else if [[ "$CORE_VERSION" == *"-"* ]]; then # current_version has prerelease, so include all versions but prefer stable - ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') + ALL_TAGS=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) if [ -n "$STABLE_TAGS" ]; then # Get the highest stable version @@ -61,7 +61,7 @@ else fi else # VERSION has no prerelease, so only consider stable releases in same track - LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1) + LATEST_CORE_TAG=$(git tag -l "core/v${CORE_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true) echo "latest core tag (stable only): $LATEST_CORE_TAG" fi PREVIOUS_CORE_VERSION=${LATEST_CORE_TAG#core/v} @@ -88,17 +88,26 @@ else FRAMEWORK_MAJOR_MINOR=$(echo "$FRAMEWORK_BASE_VERSION" | cut -d. 
-f1,2) echo " 🔍 Checking track: ${FRAMEWORK_MAJOR_MINOR}.x" - ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-') - PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) LATEST_FRAMEWORK_TAG="" - if [ -n "$STABLE_TAGS" ]; then - LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1) - echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG" + if [[ "$FRAMEWORK_VERSION" == *"-"* ]]; then + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_FRAMEWORK_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest framework tag (stable preferred): $LATEST_FRAMEWORK_TAG" + else + # No stable versions, get highest prerelease + LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG" + fi else - LATEST_FRAMEWORK_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) - echo "latest framework tag (prerelease only): $LATEST_FRAMEWORK_TAG" - fi + # VERSION has no prerelease, so only consider stable releases in same track + LATEST_FRAMEWORK_TAG=$(git tag -l "framework/v${FRAMEWORK_MAJOR_MINOR}.*" | grep -v '\-' | sort -V | tail -1 || true) + echo "latest framework tag (stable only): $LATEST_FRAMEWORK_TAG" + fi if [ -z "$LATEST_FRAMEWORK_TAG" ]; then echo " ✅ First framework release in track ${FRAMEWORK_MAJOR_MINOR}.x: $FRAMEWORK_VERSION" FRAMEWORK_NEEDS_RELEASE="true" @@ -153,20 +162,20 @@ for plugin_dir in plugins/*/; do echo " 🔍 Checking track: ${plugin_major_minor}.x" if [[ "$current_version" == *"-"* ]]; then - # current_version has prerelease, so include all versions but prefer stable - ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" 
| sort -V) - STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) - PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) - - if [ -n "$STABLE_TAGS" ]; then - # Get the highest stable version - LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) - echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" - else - # No stable versions, get highest prerelease - LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) - echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" - fi + # current_version has prerelease, so include all versions but prefer stable + ALL_TAGS=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | sort -V) + STABLE_TAGS=$(echo "$ALL_TAGS" | grep -v '\-' || true) + PRERELEASE_TAGS=$(echo "$ALL_TAGS" | grep '\-' || true) + + if [ -n "$STABLE_TAGS" ]; then + # Get the highest stable version + LATEST_PLUGIN_TAG=$(echo "$STABLE_TAGS" | tail -1) + echo "latest plugin tag (stable preferred): $LATEST_PLUGIN_TAG" + else + # No stable versions, get highest prerelease + LATEST_PLUGIN_TAG=$(echo "$PRERELEASE_TAGS" | tail -1) + echo "latest plugin tag (prerelease only): $LATEST_PLUGIN_TAG" + fi else # VERSION has no prerelease, so only consider stable releases in same track LATEST_PLUGIN_TAG=$(git tag -l "plugins/${plugin_name}/v${plugin_major_minor}.*" | grep -v '\-' | sort -V | tail -1 || true) diff --git a/.github/workflows/scripts/release-bifrost-http-prep.sh b/.github/workflows/scripts/release-bifrost-http-prep.sh index c5eff8fc52..5983453755 100755 --- a/.github/workflows/scripts/release-bifrost-http-prep.sh +++ b/.github/workflows/scripts/release-bifrost-http-prep.sh @@ -76,24 +76,32 @@ echo "🔧 Using plugin versions from version files for transport..." # Track which plugins are actually used by the transport cd transports + +# Normalize the local go.mod directive up front so prior-release artifacts +# (e.g. `go 1.26.2` written by earlier `go get` runs) don't trip GOTOOLCHAIN=local. 
+go mod edit -go=1.26.1 -toolchain=none + for plugin_name in "${!PLUGIN_VERSIONS[@]}"; do plugin_version="${PLUGIN_VERSIONS[$plugin_name]}" # Check if transport depends on this plugin if grep -q "github.com/maximhq/bifrost/plugins/$plugin_name" go.mod; then echo " 📦 Using $plugin_name plugin $plugin_version" - go_get_with_backoff "github.com/maximhq/bifrost/plugins/$plugin_name@$plugin_version" + # Textual require bump — skips loading the currently-declared version's go.mod + go mod edit -require="github.com/maximhq/bifrost/plugins/$plugin_name@$plugin_version" fi done # Also ensure core and framework are up to date echo " 🔧 Updating core to $CORE_VERSION" -go_get_with_backoff "github.com/maximhq/bifrost/core@$CORE_VERSION" +go mod edit -require="github.com/maximhq/bifrost/core@$CORE_VERSION" echo " 📦 Updating framework to $FRAMEWORK_VERSION" -go_get_with_backoff "github.com/maximhq/bifrost/framework@$FRAMEWORK_VERSION" +go mod edit -require="github.com/maximhq/bifrost/framework@$FRAMEWORK_VERSION" +# Re-normalize before tidy in case any edit reintroduced a toolchain line +go mod edit -go=1.26.1 -toolchain=none go mod tidy cd .. 
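
The textual bump above sidesteps the toolchain trap it comments on: `go mod edit -require` rewrites go.mod without resolving the module graph, so a plugin whose own go.mod declares a newer `go` directive is never consulted until the single `go mod tidy` at the end, by which point the local `go`/`toolchain` directives are already pinned. A hypothetical before/after of the transport's go.mod (module versions are illustrative, not the real release numbers):

```
module github.com/maximhq/bifrost/transports

// BEFORE: a stale `go 1.26.2` directive left behind by an earlier
// `go get` run would make GOTOOLCHAIN=local refuse to build with the
// pinned 1.26.1 toolchain. AFTER the script's edits:

go 1.26.1 // pinned by `go mod edit -go=1.26.1 -toolchain=none`

require (
	github.com/maximhq/bifrost/core v1.5.0      // bumped textually, no network fetch
	github.com/maximhq/bifrost/framework v1.5.0 // bumped textually, no network fetch
)
```

`go mod tidy` then performs the only full module-graph resolution, downloading the bumped versions and pruning the require list in one pass.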
diff --git a/.github/workflows/scripts/run-migration-tests.sh b/.github/workflows/scripts/run-migration-tests.sh index f59cd7ef25..817c73c1c3 100755 --- a/.github/workflows/scripts/run-migration-tests.sh +++ b/.github/workflows/scripts/run-migration-tests.sh @@ -133,11 +133,15 @@ cleanup() { } trap cleanup EXIT -# Get previous N transport versions (excluding prereleases) +# Get previous N transport versions (excluding prereleases) plus explicitly tested prereleases get_previous_versions() { local count="${1:-3}" cd "$REPO_ROOT" - git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||' + local stable + stable=$(git tag -l "transports/v*" | grep -v -- "-" | sort -V | tail -n "$count" | sed 's|transports/||') + # Explicitly include prerelease versions that need migration coverage + local prereleases="v1.5.0-prerelease1" + echo "$stable"$'\n'"$prereleases" | grep -v '^$' | sort -V | uniq } # Wait for bifrost to start @@ -339,6 +343,22 @@ run_postgres_sql() { -c "$sql" 2>/dev/null } +run_postgres_scalar() { + local sql="$1" + + local container + container=$(get_postgres_container) + + if [ -z "$container" ]; then + log_error "PostgreSQL container not found" + return 1 + fi + + docker exec "$container" \ + psql -U "$POSTGRES_USER" -d "$POSTGRES_DB" -t -A \ + -c "$sql" 2>/dev/null | tr -d '[:space:]' +} + run_postgres_sql_file() { local sql_file="$1" @@ -453,10 +473,10 @@ VALUES (1, 'migration-test-hash-abc123def456', $now, $now) ON CONFLICT DO NOTHING; -- governance_budgets (reset_duration is a string like "1d", "1h", etc.) 
-INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, created_at, updated_at, calendar_aligned) +INSERT INTO governance_budgets (id, max_limit, current_usage, reset_duration, last_reset, config_hash, calendar_aligned, created_at, updated_at) VALUES - ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', $now, $now, 0), - ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', $now, $now, 1) + ('budget-migration-test-1', 1000.00, 100.00, '1d', $now, 'budget-hash-001', false, $now, $now), + ('budget-migration-test-2', 5000.00, 250.00, '7d', $now, 'budget-hash-002', false, $now, $now) ON CONFLICT DO NOTHING; -- governance_rate_limits (flexible duration format with token_* and request_* columns) @@ -522,8 +542,8 @@ VALUES ('migration-test-lock', 'holder-migration-test-001', $future, $now) ON CONFLICT DO NOTHING; -- config_client (global client configuration) -INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, enable_litellm_fallbacks, config_hash, created_at, updated_at) -VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, true, 100, 10, 30, 'server', 10, false, 'client-config-hash-001', $now, $now) +INSERT INTO config_client (id, drop_excess_requests, prometheus_labels_json, allowed_origins_json, allowed_headers_json, header_filter_config_json, initial_pool_size, enable_logging, disable_content_logging, disable_db_pings_in_health, log_retention_days, enforce_governance_header, allow_direct_keys, max_request_body_size_mb, mcp_agent_depth, 
mcp_tool_execution_timeout, mcp_code_mode_binding_level, mcp_tool_sync_interval, compat_convert_text_to_chat, compat_convert_chat_to_responses, compat_should_drop_params, compat_should_convert_params, config_hash, created_at, updated_at) +VALUES (1, false, '["provider", "model"]', '["*"]', '["Authorization"]', '{}', 300, true, false, false, 365, true, false, 100, 10, 30, 'server', 10, false, false, false, true, 'client-config-hash-001', $now, $now) ON CONFLICT DO NOTHING; -- governance_config (key-value config table) @@ -623,12 +643,9 @@ CROSS JOIN config_keys ck WHERE vpc.virtual_key_id = 'vk-migration-test-1' AND ck.name = 'migration-test-key-openai' ON CONFLICT DO NOTHING; --- governance_virtual_key_mcp_configs (references virtual_keys and mcp_clients) --- We need to reference the mcp_client by its internal ID, so use a subquery -INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) -SELECT 'vk-migration-test-1', id, '["tool1"]' -FROM config_mcp_clients WHERE client_id = 'mcp-migration-test-001' -ON CONFLICT DO NOTHING; +-- governance_virtual_key_mcp_configs: handled dynamically after config_mcp_clients is inserted +-- (see generate_mcp_clients_insert_postgres/sqlite) so the subquery finds the MCP client row. +-- Both test VKs are covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. 
-- sessions (id is auto-increment integer, not a string) INSERT INTO sessions (token, expires_at, created_at, updated_at) @@ -707,6 +724,7 @@ append_dynamic_mcp_clients_insert() { generate_prompt_repo_tables_insert_postgres "$now" "$faker_sql" generate_model_parameters_insert_postgres "$now" "$faker_sql" generate_routing_targets_insert_postgres "$now" "$faker_sql" + generate_pricing_overrides_insert_postgres "$now" "$faker_sql" append_dynamic_columns_postgres "$now" "$past" "$faker_sql" else now="datetime('now')" @@ -717,6 +735,7 @@ append_dynamic_mcp_clients_insert() { generate_prompt_repo_tables_insert_sqlite "$now" "$faker_sql" "$config_db" generate_model_parameters_insert_sqlite "$now" "$faker_sql" "$config_db" generate_routing_targets_insert_sqlite "$now" "$faker_sql" "$config_db" + generate_pricing_overrides_insert_sqlite "$now" "$faker_sql" "$config_db" append_dynamic_columns_sqlite "$now" "$past" "$faker_sql" "$config_db" fi } @@ -822,6 +841,16 @@ append_dynamic_columns_postgres() { echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" fi + # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1) + if column_exists_postgres "config_keys" "ollama_url"; then + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + if column_exists_postgres "config_keys" "sgl_url"; then + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + # config_keys.encryption_status (added in v1.4.8) if column_exists_postgres "config_keys" "encryption_status"; then echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 'migration-test-key-openai';" >> "$output_file" @@ 
-949,6 +978,17 @@ append_dynamic_columns_postgres() { echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" fi + # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1) + if column_exists_postgres "logs" "image_edit_input"; then + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + if column_exists_postgres "logs" "image_variation_input"; then + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi if column_exists_postgres "logs" "video_list_output"; then echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" @@ -1190,6 +1230,61 @@ append_dynamic_columns_postgres() { echo "UPDATE governance_model_pricing SET code_interpreter_cost_per_session = NULL WHERE id = 2;" >> "$output_file" fi + # ------------------------------------------------------------------------- + # v1.5.0 columns - config store tables + # ------------------------------------------------------------------------- + + # config_client.mcp_disable_auto_tool_inject (added in v1.5.0) + if column_exists_postgres "config_client" "mcp_disable_auto_tool_inject"; then + echo "UPDATE config_client SET mcp_disable_auto_tool_inject = false WHERE id = 1;" >> "$output_file" + fi 
+ + # config_client.whitelisted_routes_json (added in v1.5.0) + if column_exists_postgres "config_client" "whitelisted_routes_json"; then + echo "UPDATE config_client SET whitelisted_routes_json = '[]' WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0) + # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false + # vk-migration-test-2 has no key rows, so old "empty=allow-all" semantics -> allow_all_keys=true + if column_exists_postgres "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = false WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = true WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.5.0 columns - log store tables + # ------------------------------------------------------------------------- + + # logs.plugin_logs (added in v1.5.0) + if column_exists_postgres "logs" "plugin_logs"; then + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + + # ------------------------------------------------------------------------- + # v1.4.19 columns + # ------------------------------------------------------------------------- + + # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later) + if column_exists_postgres "governance_model_pricing" "context_length"; then + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> 
"$output_file" + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "max_input_tokens"; then + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "max_output_tokens"; then + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "architecture"; then + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file" + fi + # ------------------------------------------------------------------------- # v1.4.17 columns # ------------------------------------------------------------------------- @@ -1305,6 +1400,28 @@ append_dynamic_columns_postgres() { echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-001';" >> "$output_file" echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-002';" >> "$output_file" fi + + # ------------------------------------------------------------------------- + # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle + # ------------------------------------------------------------------------- + + # config_client.enable_litellm_fallbacks (added in v1.4.22) + if column_exists_postgres "config_client" "enable_litellm_fallbacks"; then + echo "UPDATE config_client SET enable_litellm_fallbacks = false WHERE id = 1;" >> "$output_file" + fi + + if column_exists_postgres "governance_model_pricing" "input_cost_per_token_flex"; then + echo "UPDATE 
governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_postgres "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi } # Append dynamic column UPDATEs for columns that may not exist in older schemas (SQLite) @@ -1410,6 +1527,16 @@ append_dynamic_columns_sqlite() { echo "UPDATE config_keys SET vllm_model_name = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" fi + # config_keys.ollama_url, sgl_url (added in v1.5.0-prerelease1) + if column_exists_sqlite "$config_db" "config_keys" "ollama_url"; then + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET ollama_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "config_keys" "sgl_url"; then + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-openai';" >> "$output_file" + echo "UPDATE config_keys SET sgl_url = '' WHERE name = 'migration-test-key-anthropic';" >> "$output_file" + fi + # config_keys.encryption_status (added in v1.4.8) if column_exists_sqlite "$config_db" "config_keys" "encryption_status"; then echo "UPDATE config_keys SET encryption_status = 'plain_text' WHERE name = 
'migration-test-key-openai';" >> "$output_file" @@ -1528,6 +1655,17 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_download_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + # logs.image_edit_input, image_variation_input (added in v1.5.0-prerelease1) + if column_exists_sqlite "$logs_db" "logs" "image_edit_input"; then + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_edit_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi + if column_exists_sqlite "$logs_db" "logs" "image_variation_input"; then + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET image_variation_input = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + fi echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-001';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET video_list_output = '' WHERE id = 'log-migration-test-003';" >> "$output_file" @@ -1755,6 +1893,58 @@ append_dynamic_columns_sqlite() { echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-002';" >> "$output_file" echo "UPDATE logs SET cached_read_tokens = 0 WHERE id = 'log-migration-test-003';" >> "$output_file" + # ------------------------------------------------------------------------- + # v1.5.0 columns - config store tables + # 
------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.mcp_disable_auto_tool_inject (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_client" "mcp_disable_auto_tool_inject"; then + echo "UPDATE config_client SET mcp_disable_auto_tool_inject = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_virtual_key_provider_configs.allow_all_keys (added in v1.5.0) + # vk-migration-test-1 has a key in the join table, so old behavior was restricted to that key -> allow_all_keys=false + # vk-migration-test-2 has no key rows, so old "empty=allow-all" semantics -> allow_all_keys=true + if column_exists_sqlite "$config_db" "governance_virtual_key_provider_configs" "allow_all_keys"; then + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 0 WHERE virtual_key_id = 'vk-migration-test-1';" >> "$output_file" + echo "UPDATE governance_virtual_key_provider_configs SET allow_all_keys = 1 WHERE virtual_key_id = 'vk-migration-test-2';" >> "$output_file" + fi + fi + + # ------------------------------------------------------------------------- + # v1.5.0 columns - log store tables (emitted unconditionally; fail silently on config_db) + # ------------------------------------------------------------------------- + + # logs.plugin_logs (added in v1.5.0) + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-001';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-002';" >> "$output_file" + echo "UPDATE logs SET plugin_logs = '' WHERE id = 'log-migration-test-003';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.19 columns + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # governance_model_pricing: context_length, max_input_tokens, max_output_tokens, architecture (added in v1.4.19, removed later) + if 
column_exists_sqlite "$config_db" "governance_model_pricing" "context_length"; then + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET context_length = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "max_input_tokens"; then + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_input_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "max_output_tokens"; then + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET max_output_tokens = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "architecture"; then + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET architecture = NULL WHERE id = 2;" >> "$output_file" + fi + fi + # ------------------------------------------------------------------------- # v1.4.17 columns # ------------------------------------------------------------------------- @@ -1848,6 +2038,22 @@ append_dynamic_columns_sqlite() { echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_above_272k_tokens_priority = NULL WHERE id = 1;" >> "$output_file" echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_above_272k_tokens_priority = NULL WHERE id = 2;" >> "$output_file" fi + + # ------------------------------------------------------------------------- + # v1.4.22 columns - governance_model_pricing flex tier pricing + # ------------------------------------------------------------------------- + if column_exists_sqlite "$config_db" "governance_model_pricing" 
"input_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi fi # ------------------------------------------------------------------------- @@ -1862,6 +2068,31 @@ append_dynamic_columns_sqlite() { # mcp_tool_logs.request_id (added in v1.4.21) echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-001';" >> "$output_file" echo "UPDATE mcp_tool_logs SET request_id = '' WHERE id = 'mcp-log-migration-002';" >> "$output_file" + + # ------------------------------------------------------------------------- + # v1.4.22 columns - flex tier pricing and litellm fallbacks toggle + # ------------------------------------------------------------------------- + + if [ -f "$config_db" ]; then + # config_client.enable_litellm_fallbacks (added in v1.4.22) + if column_exists_sqlite "$config_db" "config_client" "enable_litellm_fallbacks"; then + echo "UPDATE config_client SET enable_litellm_fallbacks = 0 WHERE id = 1;" >> "$output_file" + fi + + # governance_model_pricing flex tier columns (added in v1.4.22) + if column_exists_sqlite "$config_db" "governance_model_pricing" "input_cost_per_token_flex"; then + 
echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET input_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "output_cost_per_token_flex"; then + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET output_cost_per_token_flex = NULL WHERE id = 2;" >> "$output_file" + fi + if column_exists_sqlite "$config_db" "governance_model_pricing" "cache_read_input_token_cost_flex"; then + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 1;" >> "$output_file" + echo "UPDATE governance_model_pricing SET cache_read_input_token_cost_flex = NULL WHERE id = 2;" >> "$output_file" + fi + fi } # ============================================================================ @@ -1949,10 +2180,29 @@ generate_mcp_clients_insert_postgres() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_postgres "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, false" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. 
+ # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. + echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Get columns that are auto-increment primary keys (don't need faker coverage) @@ -2163,10 +2413,29 @@ generate_mcp_clients_insert_sqlite() { vals="$vals, 'plain_text'" fi + # config_mcp_clients.allowed_extra_headers_json (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allowed_extra_headers_json"; then + cols="$cols, allowed_extra_headers_json" + vals="$vals, '[]'" + fi + + # config_mcp_clients.allow_on_all_virtual_keys (added in v1.5.0) + if column_exists_sqlite "$config_db" "config_mcp_clients" "allow_on_all_virtual_keys"; then + cols="$cols, allow_on_all_virtual_keys" + vals="$vals, 0" + fi + # Append the dynamic INSERT to the output file echo "" >> "$output_file" echo "-- config_mcp_clients (MCP server configurations - dynamically generated based on schema)" >> "$output_file" echo "INSERT INTO config_mcp_clients ($cols) VALUES ($vals) ON CONFLICT DO NOTHING;" >> "$output_file" + + # governance_virtual_key_mcp_configs: link both test VKs to the test MCP client. + # Must run AFTER config_mcp_clients INSERT so the subquery finds the row. + # Both VKs covered to prevent migrationBackfillEmptyVirtualKeyConfigs from adding rows. 
+ echo "" >> "$output_file" + echo "-- governance_virtual_key_mcp_configs (dynamically generated after config_mcp_clients)" >> "$output_file" + echo "INSERT INTO governance_virtual_key_mcp_configs (virtual_key_id, mcp_client_id, tools_to_execute) SELECT vk.id, mc.id, '[\"tool1\"]' FROM governance_virtual_keys vk CROSS JOIN config_mcp_clients mc WHERE mc.client_id = 'mcp-migration-test-001' AND vk.id IN ('vk-migration-test-1', 'vk-migration-test-2') ON CONFLICT DO NOTHING;" >> "$output_file" } # Generate async_jobs INSERT based on schema existence for PostgreSQL @@ -2395,6 +2664,49 @@ generate_routing_targets_insert_sqlite() { echo "INSERT INTO routing_targets (rule_id, provider, model, key_id, weight) VALUES ('rule-migration-test-2', NULL, NULL, NULL, 0.3) ON CONFLICT DO NOTHING;" >> "$output_file" } +# Generate governance_pricing_overrides INSERT for PostgreSQL +# This table was added in v1.5.0 as part of the custom pricing refactor. +# Two rows: one global (no FK deps) and one virtual_key-scoped (references vk-migration-test-1). +generate_pricing_overrides_insert_postgres() { + local now="$1" + local output_file="$2" + + # Check if the table exists + if ! 
column_exists_postgres "governance_pricing_overrides" "id"; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + +# Generate governance_pricing_overrides INSERT for SQLite +# This table was added in v1.5.0 as part of the custom pricing refactor. +generate_pricing_overrides_insert_sqlite() { + local now="$1" + local output_file="$2" + local config_db="$3" + + # Check if the table exists in the database + if [ ! 
-f "$config_db" ]; then + return + fi + + local table_exists + table_exists=$(sqlite3 "$config_db" "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='governance_pricing_overrides';" 2>/dev/null || echo "0") + + if [ "$table_exists" != "1" ]; then + return + fi + + echo "" >> "$output_file" + echo "-- governance_pricing_overrides (scoped pricing overrides - added in v1.5.0, dynamically generated)" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-001', 'Migration Test Override Global', 'global', NULL, NULL, NULL, 'exact', 'gpt-4', '[]', '{\"input_cost_per_token\": 0.00001}', 'po-hash-001', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" + echo "INSERT INTO governance_pricing_overrides (id, name, scope_kind, virtual_key_id, provider_id, provider_key_id, match_type, pattern, request_types_json, pricing_patch_json, config_hash, created_at, updated_at) VALUES ('pricing-override-migration-002', 'Migration Test Override VK', 'virtual_key', 'vk-migration-test-1', NULL, NULL, 'prefix', 'claude', '[]', '{\"output_cost_per_token\": 0.00002}', 'po-hash-002', $now, $now) ON CONFLICT DO NOTHING;" >> "$output_file" +} + # Validate faker column coverage for SQLite validate_faker_column_coverage_sqlite() { local faker_sql="$1" @@ -2613,6 +2925,7 @@ compare_postgres_snapshots() { # - network_config_json, concurrency_buffer_json, proxy_config_json, custom_provider_config_json: # JSON fields that get normalized with default values during migration # - budget_id, rate_limit_id: governance fields that may be reset or initialized during migrations + # - virtual_key_id, provider_config_id: new FK columns on governance_budgets (added by multi-budget migration) # - status, description: key validation runs after migration, updating these fields # for 
invalid/test keys (e.g., status becomes "list_models_failed") local ignore_columns="updated_at config_hash created_at models_json weight allowed_models network_config_json concurrency_buffer_json proxy_config_json custom_provider_config_json budget_id rate_limit_id status description" @@ -2669,6 +2982,24 @@ compare_postgres_snapshots() { if [ "$table" = "routing_rules" ]; then dropped_columns="$dropped_columns provider model" fi + # azure_deployments_json, vertex_deployments_json, bedrock_deployments_json, replicate_deployments_json + # (dropped from config_keys - migrated to provider-level deployment config) + if [ "$table" = "config_keys" ]; then + dropped_columns="$dropped_columns azure_deployments_json vertex_deployments_json bedrock_deployments_json replicate_deployments_json" + fi + # budget_id (dropped from governance_virtual_keys and governance_virtual_key_provider_configs + # in add_multi_budget_tables - ownership moved to governance_budgets.virtual_key_id/provider_config_id) + if [ "$table" = "governance_virtual_keys" ] || [ "$table" = "governance_virtual_key_provider_configs" ]; then + dropped_columns="$dropped_columns budget_id" + fi + # calendar_aligned (dropped from governance_budgets in add_multi_budget_tables - moved to governance_virtual_keys.calendar_aligned) + if [ "$table" = "governance_budgets" ]; then + dropped_columns="$dropped_columns calendar_aligned" + fi + # enable_litellm_fallbacks (dropped from config_client in latest cut - behavior moved elsewhere) + if [ "$table" = "config_client" ]; then + dropped_columns="$dropped_columns enable_litellm_fallbacks" + fi local before_col_array IFS=',' read -ra before_col_array <<< "$before_columns" @@ -2729,7 +3060,12 @@ compare_postgres_snapshots() { local col_idx=1 for col in "${before_col_array[@]}"; do # Skip columns that are expected to change - if [[ " $ignore_columns " == *" $col "* ]]; then + # virtual_key_id, provider_config_id: only ignore on governance_budgets (new FK columns from 
multi-budget migration) + local table_ignore_columns="$ignore_columns" + if [ "$table" = "governance_budgets" ]; then + table_ignore_columns="$table_ignore_columns virtual_key_id provider_config_id" + fi + if [[ " $table_ignore_columns " == *" $col "* ]]; then col_idx=$((col_idx + 1)) continue fi @@ -2820,6 +3156,84 @@ compare_postgres_snapshots() { # Validation Functions (simplified, uses snapshots) # ============================================================================ +# verify_budget_migration checks that the multi-budget FK migration correctly +# moved budget ownership from VK/ProviderConfig budget_id columns to +# governance_budgets.virtual_key_id / governance_budgets.provider_config_id +verify_budget_migration_postgres() { + log_info "Verifying budget migration (budget_id → virtual_key_id/provider_config_id)..." + local failed=0 + + # Check: budget-migration-test-1 was linked to vk-migration-test-1 via budget_id + # After migration, governance_budgets.virtual_key_id should be set + local vk_budget_count + vk_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-1' AND virtual_key_id = 'vk-migration-test-1'") + if [ "$vk_budget_count" = "1" ]; then + log_info " VK budget migration: budget-migration-test-1 → vk-migration-test-1 ✓" + else + log_warn " VK budget migration: budget-migration-test-1 virtual_key_id not set (count=$vk_budget_count) — may be expected if old version didn't have budget_id on VK" + fi + + # Check: budget-migration-test-2 was linked to provider config via budget_id + # After migration, governance_budgets.provider_config_id should be set + local pc_budget_count + pc_budget_count=$(run_postgres_scalar "SELECT COUNT(*) FROM governance_budgets WHERE id = 'budget-migration-test-2' AND provider_config_id IS NOT NULL") + if [ "$pc_budget_count" = "1" ]; then + log_info " PC budget migration: budget-migration-test-2 → provider_config ✓" + else + log_warn " PC budget migration: 
budget-migration-test-2 provider_config_id not set (count=$pc_budget_count) — may be expected if old version didn't have budget_id on PC" + fi + + # Check: virtual_key_id and provider_config_id columns exist on governance_budgets + local has_vk_col + has_vk_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'virtual_key_id'") + if [ "$has_vk_col" = "1" ]; then + log_info " Column governance_budgets.virtual_key_id exists ✓" + else + log_error " Column governance_budgets.virtual_key_id MISSING!" + failed=1 + fi + + local has_pc_col + has_pc_col=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_budgets' AND column_name = 'provider_config_id'") + if [ "$has_pc_col" = "1" ]; then + log_info " Column governance_budgets.provider_config_id exists ✓" + else + log_error " Column governance_budgets.provider_config_id MISSING!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_keys + local vk_has_budget_id + vk_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_keys' AND column_name = 'budget_id'") + if [ "$vk_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_keys.budget_id dropped ✓" + else + log_error " Column governance_virtual_keys.budget_id still exists!" + failed=1 + fi + + # Check: budget_id column should be dropped from governance_virtual_key_provider_configs + local pc_has_budget_id + pc_has_budget_id=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'governance_virtual_key_provider_configs' AND column_name = 'budget_id'") + if [ "$pc_has_budget_id" = "0" ]; then + log_info " Column governance_virtual_key_provider_configs.budget_id dropped ✓" + else + log_error " Column governance_virtual_key_provider_configs.budget_id still exists!" 
+ failed=1 + fi + + # Check: junction tables should not exist + local junction_vk + junction_vk=$(run_postgres_scalar "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = 'governance_virtual_key_budgets'") + if [ "$junction_vk" = "0" ]; then + log_info " Junction table governance_virtual_key_budgets dropped ✓" + else + log_warn " Junction table governance_virtual_key_budgets still exists (may not have existed in old version)" + fi + + return $failed +} + validate_postgres_data() { local before_snapshot="$1" local after_snapshot="$2" @@ -2983,7 +3397,7 @@ EOF --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$server_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$server_log" 120; then + if ! wait_for_bifrost "$server_log" 300; then log_error "Failed to start bifrost $version" cat "$server_log" 2>/dev/null || true stop_bifrost @@ -3024,7 +3438,7 @@ EOF "$current_binary" --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$current_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$current_log" 120; then + if ! wait_for_bifrost "$current_log" 300; then log_error "Current version failed to start after migrating from $version" cat "$current_log" stop_bifrost @@ -3069,6 +3483,13 @@ EOF return 1 fi + # STEP 6: Verify budget migration (budget_id → virtual_key_id/provider_config_id) + if ! verify_budget_migration_postgres; then + log_error "Budget migration verification failed after migration from $version" + stop_bifrost + return 1 + fi + stop_bifrost log_info "Migration from $version: SUCCESS" done @@ -3175,7 +3596,7 @@ EOF --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$server_log" 2>&1 & BIFROST_PID=$! - if ! wait_for_bifrost "$server_log" 120; then + if ! wait_for_bifrost "$server_log" 300; then log_error "Failed to start bifrost $version" cat "$server_log" 2>/dev/null || true stop_bifrost @@ -3215,7 +3636,7 @@ EOF "$current_binary" --app-dir "$TEMP_DIR" --port "$BIFROST_PORT" > "$current_log" 2>&1 & BIFROST_PID=$! - if ! 
wait_for_bifrost "$current_log" 120; then + if ! wait_for_bifrost "$current_log" 300; then log_error "Current version failed to start after migrating from $version" cat "$current_log" stop_bifrost @@ -3289,4 +3710,4 @@ main() { exit $exit_code } -main "$@" +main "$@" \ No newline at end of file diff --git a/.github/workflows/scripts/schemasync/go.mod b/.github/workflows/scripts/schemasync/go.mod new file mode 100644 index 0000000000..0b8a2eee00 --- /dev/null +++ b/.github/workflows/scripts/schemasync/go.mod @@ -0,0 +1,10 @@ +module github.com/maximhq/bifrost/tools/schema-sync + +go 1.26.2 + +require golang.org/x/tools v0.30.0 + +require ( + golang.org/x/mod v0.23.0 // indirect + golang.org/x/sync v0.11.0 // indirect +) diff --git a/.github/workflows/scripts/schemasync/go.sum b/.github/workflows/scripts/schemasync/go.sum new file mode 100644 index 0000000000..68d0914a62 --- /dev/null +++ b/.github/workflows/scripts/schemasync/go.sum @@ -0,0 +1,8 @@ +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= diff --git a/.github/workflows/scripts/schemasync/main.go b/.github/workflows/scripts/schemasync/main.go new file mode 100644 index 0000000000..ff81e0902f --- /dev/null +++ b/.github/workflows/scripts/schemasync/main.go @@ -0,0 +1,1075 @@ +// schemasync validates that Bifrost Go config types stay in sync with +// transports/config.schema.json. 
+// +// Starting from a configured entry-point type (default: ConfigData in +// transports/bifrost-http/lib), it recursively walks every nested struct +// field via go/types. For each field it verifies: +// +// 1. The json:"X" tag has a corresponding property in config.schema.json at +// the propagated schema path (handling $ref, allOf, oneOf, if/then/else). +// 2. If the field's Go type is a named string type with const declarations, +// the set of Go constant values matches the schema's enum array. +// +// Exit 0 on full agreement, 1 on any mismatch. +package main + +import ( + "encoding/json" + "flag" + "fmt" + "go/constant" + "go/types" + "os" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +type entrypoint struct { + pkg string // Go import path + typeName string // exported type name + schemaPath string // JSON pointer path in config.schema.json (e.g. "/properties") + moduleDir string // directory (relative to --pkg-root) that contains the go.mod +} + +var entrypoints = []entrypoint{ + { + pkg: "github.com/maximhq/bifrost/transports/bifrost-http/lib", + typeName: "ConfigData", + schemaPath: "", // root schema node — collectProperties will find .properties + moduleDir: "transports", + }, +} + +// Schema properties that intentionally have no Go counterpart and vice versa. +// Key is a JSON pointer path; value is a short reason. +var ignoreSchemaProps = map[string]string{ + "/properties/$schema": "JSON schema self-reference", + // GORM foreignKey slice relations that ARE user-submittable config input. + // Go-side: schemasync skips them via the gorm-tag filter; schema-side: + // these entries prevent the missing-in-go warning for them. 
+ "/properties/governance/properties/virtual_keys/items/properties/provider_configs": "gorm fk slice; user-submittable", + "/properties/governance/properties/virtual_keys/items/properties/mcp_configs": "gorm fk slice; user-submittable", + "/properties/governance/properties/routing_rules/items/properties/targets": "gorm fk slice; user-submittable", + // MCP headers map — documented escape hatch is envFrom: + // plus env.X references in values; no chart-native secretRef. + "/properties/mcp/properties/client_configs/items/properties/headers/additionalProperties": "documented envFrom pattern", + // Object-storage identity fields (bucket/region/endpoint/project_id) are + // EnvVar-typed for flexibility but are not inherently secret. Operators + // can write `env.MY_VAR` in values and use envFrom to inject. Access + // keys, session tokens, and credentials DO have chart-native secret + // support via `storage.logsStore.objectStorage.existingSecret`. + "/properties/logs_store/properties/object_storage/properties/bucket": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/region": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/endpoint": "not a secret; env.X + envFrom pattern", + "/properties/logs_store/properties/object_storage/properties/project_id": "not a secret; env.X + envFrom pattern", +} + +// ignoreGoFields keys are "schemaPath|fieldName"; value is the reason. +var ignoreGoFields = map[string]string{ + "|auth_config": "deprecated; moved to governance.auth_config", +} + +// ignoreGoFieldNames are field names (regardless of parent path) that are +// DB bookkeeping or runtime-derived — never part of user-submitted config. 
+var ignoreGoFieldNames = map[string]string{ + "created_at": "DB bookkeeping", + "updated_at": "DB bookkeeping", + "config_hash": "internal hash", + "status": "runtime-derived", + "state": "runtime-derived", +} + +// opaqueLeafTypes are named Go types that have custom JSON marshalling and +// should be treated as leaves. The walker does NOT recurse into their fields, +// and they are collected for downstream checks (e.g., EnvVar → helm secret). +var opaqueLeafTypes = map[string]string{ + "github.com/maximhq/bifrost/core/schemas.EnvVar": "env-aware string; custom JSON", +} + +// envVarLocation records where an EnvVar-typed field appears in config.json +// so a downstream pass can confirm the helm chart supports Secret-backed +// injection (existingSecret / secretRef / env.BIFROST_*) for that path. +type envVarLocation struct { + schemaPath string + goPath string +} + +// Finding categorises every issue the tool surfaces so the final report can +// group by category and render as a table. +type Finding struct { + Category string // e.g. "missing-in-schema", "missing-in-go", "enum-drift", "envvar-no-secret" + Severity string // "ERROR" or "WARN" + Path string // schema path or enum path + Detail string // field name, Go path, missing/extra values, etc. 
+ Go string // Go-side location (package.Type.Field) +} + +type checker struct { + schema map[string]any + pkgs map[string]*packages.Package // path → pkg + // enumConsts[namedType] -> list of string values found in any loaded package + enumConsts map[string][]string + // visited type names to break cycles + visited map[string]bool + // envVarFields records where EnvVar types occur, for downstream checks + envVarFields []envVarLocation + findings []Finding +} + +func main() { + schemaFlag := flag.String("schema", "transports/config.schema.json", "path to config.schema.json") + pkgDir := flag.String("pkg-root", ".", "repo root used as packages.Load dir") + helmValuesFlag := flag.String("helm-values", "helm-charts/bifrost/values.schema.json", "path to helm values.schema.json (for EnvVar secret-support check)") + helmHelpersFlag := flag.String("helm-helpers", "helm-charts/bifrost/templates/_helpers.tpl", "path to helm _helpers.tpl (for env.BIFROST_* emission detection)") + flag.Parse() + + schemaBytes, err := os.ReadFile(*schemaFlag) + if err != nil { + fmt.Fprintf(os.Stderr, "read schema: %v\n", err) + os.Exit(2) + } + var schema map[string]any + if err := json.Unmarshal(schemaBytes, &schema); err != nil { + fmt.Fprintf(os.Stderr, "parse schema: %v\n", err) + os.Exit(2) + } + + // Group entrypoints by moduleDir so we load each module's package graph once. + byModule := map[string][]entrypoint{} + orderedMods := []string{} + for _, e := range entrypoints { + if _, seen := byModule[e.moduleDir]; !seen { + orderedMods = append(orderedMods, e.moduleDir) + } + byModule[e.moduleDir] = append(byModule[e.moduleDir], e) + } + absRoot, err := filepath.Abs(*pkgDir) + if err != nil { + fmt.Fprintf(os.Stderr, "abs pkg-root: %v\n", err) + os.Exit(2) + } + // Always use the repo's go.work so local modules resolve against each + // other (not against registry tarballs). The tool refuses to run without + // go.work — that's the only configuration bifrost is tested against. 
+ goworkPath := filepath.Join(absRoot, "go.work") + if _, err := os.Stat(goworkPath); err != nil { + fmt.Fprintf(os.Stderr, "schemasync requires go.work at %s: %v\n", goworkPath, err) + os.Exit(2) + } + + allPkgs := map[string]*packages.Package{} + for _, mod := range orderedMods { + modDir := filepath.Join(absRoot, mod) + env := append([]string{}, os.Environ()...) + env = append(env, "GOWORK="+goworkPath) + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedFiles, + Dir: modDir, + Env: env, + } + imports := []string{} + for _, e := range byModule[mod] { + imports = append(imports, e.pkg) + } + pkgs, err := packages.Load(cfg, imports...) + if err != nil { + fmt.Fprintf(os.Stderr, "load %s: %v\n", mod, err) + os.Exit(2) + } + hadLoadErr := false + packages.Visit(pkgs, nil, func(p *packages.Package) { + for _, e := range p.Errors { + fmt.Fprintln(os.Stderr, e) + hadLoadErr = true + } + }) + if hadLoadErr { + os.Exit(2) + } + for k, v := range collectPkgs(pkgs) { + allPkgs[k] = v + } + } + + c := &checker{ + schema: schema, + pkgs: allPkgs, + enumConsts: map[string][]string{}, + visited: map[string]bool{}, + } + c.collectConsts() + + for _, e := range entrypoints { + p := c.pkgs[e.pkg] + if p == nil { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: "package not loaded: " + e.pkg}) + continue + } + obj := p.Types.Scope().Lookup(e.typeName) + if obj == nil { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: fmt.Sprintf("type %s not found in %s", e.typeName, e.pkg)}) + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + c.add(Finding{Category: "entrypoint", Severity: "ERROR", Detail: fmt.Sprintf("%s.%s is not a named type", e.pkg, e.typeName)}) + continue + } + c.walkType(named, e.schemaPath, fmt.Sprintf("%s.%s", e.pkg, e.typeName)) + } + + // EnvVar → helm-chart secret-support pass. 
For each Go field typed as + // schemas.EnvVar, the helm chart must either (a) emit an env.BIFROST_* + // placeholder for that JSON path via _helpers.tpl, or (b) expose a + // secretRef/existingSecret knob in values.schema.json at the equivalent + // camelCase location. If neither, warn. + c.checkEnvVarHelmSupport(*helmValuesFlag, *helmHelpersFlag) + + printReport(os.Stderr, c.findings) + errCount := c.countErrs() + warnCount := c.countWarns() + if errCount > 0 { + fmt.Fprintf(os.Stderr, "\nschemasync: %d errors, %d warnings\n", errCount, warnCount) + os.Exit(1) + } + fmt.Fprintf(os.Stderr, "\nschemasync: OK (%d warnings)\n", warnCount) +} + +// printReport groups findings by category and prints a markdown-style table +// for each non-empty group. +func printReport(w interface{ Write([]byte) (int, error) }, findings []Finding) { + if len(findings) == 0 { + return + } + groups := map[string][]Finding{} + order := []string{} + for _, f := range findings { + if _, ok := groups[f.Category]; !ok { + order = append(order, f.Category) + } + groups[f.Category] = append(groups[f.Category], f) + } + titles := map[string]string{ + "missing-in-schema": "Missing in config.schema.json (Go has field, schema doesn't) — ERRORS", + "missing-in-go": "Missing in Go (schema has property, ConfigData doesn't) — WARNINGS", + "enum-drift": "Enum drift (Go constants vs schema enum array)", + "enum-no-schema": "Go enum types with no schema `enum` constraint — WARNINGS", + "envvar-no-secret": "EnvVar fields lacking chart-native Secret support — WARNINGS", + "schema-path-not-found": "Schema path not found for a walked Go type — ERRORS", + "entrypoint": "Entrypoint problems — ERRORS", + } + for _, cat := range order { + items := groups[cat] + title := titles[cat] + if title == "" { + title = cat + } + fmt.Fprintf(w.(interface{ Write([]byte) (int, error) }), "\n### %s (%d)\n\n", title, len(items)) + // Pick columns based on category for readability. 
+ switch cat { + case "missing-in-schema", "schema-path-not-found": + renderTable(w, []string{"severity", "schema path", "Go location"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Go}) + } + return out + }()) + case "missing-in-go": + renderTable(w, []string{"severity", "schema path", "property", "Go parent"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Detail, f.Go}) + } + return out + }()) + case "enum-drift", "enum-no-schema": + renderTable(w, []string{"severity", "enum path", "drift", "Go location"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Detail, f.Go}) + } + return out + }()) + case "envvar-no-secret": + renderTable(w, []string{"severity", "config path", "Go location", "note"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Path, f.Go, f.Detail}) + } + return out + }()) + default: + renderTable(w, []string{"severity", "detail"}, func() [][]string { + out := [][]string{} + for _, f := range items { + out = append(out, []string{f.Severity, f.Detail}) + } + return out + }()) + } + } +} + +// renderTable writes a markdown table. Truncates long cells to keep width sane. 
+func renderTable(w interface{ Write([]byte) (int, error) }, headers []string, rows [][]string) { + const maxCol = 80 + widths := make([]int, len(headers)) + for i, h := range headers { + widths[i] = len(h) + } + truncate := func(s string) string { + if len(s) <= maxCol { + return s + } + return s[:maxCol-1] + "…" + } + trimmed := make([][]string, len(rows)) + for i, r := range rows { + trimmed[i] = make([]string, len(r)) + for j, cell := range r { + trimmed[i][j] = truncate(cell) + if j < len(widths) && len(trimmed[i][j]) > widths[j] { + widths[j] = len(trimmed[i][j]) + } + } + } + writeRow := func(cells []string) { + var sb strings.Builder + sb.WriteString("| ") + for i, c := range cells { + sb.WriteString(c) + if pad := widths[i] - len(c); pad > 0 { + sb.WriteString(strings.Repeat(" ", pad)) + } + sb.WriteString(" | ") + } + sb.WriteString("\n") + _, _ = w.Write([]byte(sb.String())) + } + writeRow(headers) + sep := make([]string, len(headers)) + for i := range headers { + sep[i] = strings.Repeat("-", widths[i]) + } + writeRow(sep) + for _, r := range trimmed { + writeRow(r) + } +} + +// checkEnvVarHelmSupport verifies that every Go field of type schemas.EnvVar +// has a way to be sourced from a Kubernetes secret via the helm chart. Proof +// of support is any of: +// +// 1. An `env.BIFROST_*` string literal appears in _helpers.tpl (indicating +// a rewrite is wired up for the corresponding config path), OR +// 2. values.schema.json declares a `secretRef` or `existingSecret` object +// at the camelCase equivalent of the schema path. +// +// Neither heuristic is perfect — this is a structural review aid, not a +// proof. Treat misses as warnings so they don't block CI on borderline cases. 
+func (c *checker) checkEnvVarHelmSupport(valuesPath, helpersPath string) { + helpersBytes, err := os.ReadFile(helpersPath) + if err != nil { + c.add(Finding{Category: "envvar-no-secret", Severity: "WARN", Detail: fmt.Sprintf("could not read helm helpers %s: %v — skipping EnvVar helm-support check", helpersPath, err)}) + return + } + helpers := string(helpersBytes) + // Extract every env.BIFROST_* token mentioned in _helpers.tpl. + envBifrostMentions := map[string]bool{} + for _, line := range strings.Split(helpers, "\n") { + // crude extraction: look for "env.BIFROST_X" substrings + idx := 0 + for idx < len(line) { + k := strings.Index(line[idx:], "env.BIFROST_") + if k < 0 { + break + } + start := idx + k + end := start + for end < len(line) { + ch := line[end] + if ch == '"' || ch == ' ' || ch == '\t' || ch == '}' || ch == ')' { + break + } + end++ + } + envBifrostMentions[line[start:end]] = true + idx = end + } + } + + valuesBytes, err := os.ReadFile(valuesPath) + hasValues := err == nil + var valuesSchema map[string]any + if hasValues { + _ = json.Unmarshal(valuesBytes, &valuesSchema) + } + + for _, loc := range c.envVarFields { + // Heuristic 1: any env.BIFROST_* is present in helpers — broad acceptance. + // We can't easily map a specific EnvVar field to a specific env var + // without per-field config, so we just check that the helpers file + // has AT LEAST ONE envBifrost mention that maps to this field's path. + // To make this stricter, we look for a helpers line mentioning either + // the camelCase field's parent path or an env var matching it. + camel := schemaPathToCamelCase(loc.schemaPath) + matched := false + // Heuristic 2: values.schema.json declares secretRef under the parent path. 
+ if hasValues && valuesSchema != nil {
+ if hasSecretRefAt(valuesSchema, camel) {
+ matched = true
+ }
+ }
+ if !matched && len(envBifrostMentions) > 0 {
+ // Fall back to name-similarity matching against env.BIFROST_* mentions.
+ // Mere presence of some wiring somewhere is not enough to accept; the
+ // mention must resemble this field's last schema path component.
+ tail := lastSchemaComponent(loc.schemaPath)
+ for mention := range envBifrostMentions {
+ up := strings.ToUpper(tail)
+ if strings.Contains(mention, "_"+up) || strings.HasSuffix(mention, up) {
+ matched = true
+ break
+ }
+ }
+ }
+ if !matched {
+ if _, ignored := ignoreSchemaProps[loc.schemaPath]; ignored {
+ continue
+ }
+ c.add(Finding{
+ Category: "envvar-no-secret",
+ Severity: "WARN",
+ Path: loc.schemaPath,
+ Detail: "helm has no secretRef/existingSecret at " + camel + " or parent",
+ Go: loc.goPath,
+ })
+ }
+ }
+}
+
+// schemaPathToCamelCase converts a JSON pointer like
+// "/properties/governance/properties/auth_config/properties/admin_username"
+// into a best-effort camelCase helm values path like
+// "properties.bifrost.properties.governance.properties.authConfig.properties.adminUsername".
+func schemaPathToCamelCase(p string) string { + parts := strings.Split(strings.TrimPrefix(p, "/"), "/") + out := []string{"properties", "bifrost"} + for _, part := range parts { + if part == "" { + continue + } + if part == "properties" { + out = append(out, "properties") + continue + } + out = append(out, snakeToCamel(part)) + } + return strings.Join(out, ".") +} + +func snakeToCamel(s string) string { + parts := strings.Split(s, "_") + for i := 1; i < len(parts); i++ { + if parts[i] != "" { + parts[i] = strings.ToUpper(parts[i][:1]) + parts[i][1:] + } + } + return strings.Join(parts, "") +} + +func lastSchemaComponent(p string) string { + parts := strings.Split(p, "/") + for i := len(parts) - 1; i >= 0; i-- { + if parts[i] != "" && parts[i] != "properties" { + return parts[i] + } + } + return "" +} + +// hasSecretRefAt returns true if EITHER (a) the target subtree declares a +// secretRef/existingSecret/*Secret knob inside its own "properties", OR +// (b) a SIBLING of the target (at the same properties-map level) is named +// "Secret" / "secretRef" / "existingSecret" / has "Secret" suffix. +// Sibling match is how the helm chart's encryptionKey + encryptionKeySecret +// pattern works: the Secret-source knob is a sibling of the field itself. +func hasSecretRefAt(schema map[string]any, dotted string) bool { + parts := strings.Split(dotted, ".") + var cur any = schema + var propsAtTarget map[string]any // map in which the last non-"properties" part lives + var targetName string + for _, p := range parts { + m, ok := cur.(map[string]any) + if !ok { + return false + } + // Resolve $ref at this node before descending. 
+ if ref, ok := m["$ref"].(string); ok && strings.HasPrefix(ref, "#/") { + resolved := jsonPointerGet(schema, strings.TrimPrefix(ref, "#/")) + if rm, ok := resolved.(map[string]any); ok { + m = rm + } + } + if p != "properties" { + propsAtTarget = m + targetName = p + } + next, present := m[p] + if !present { + break + } + cur = next + } + // (a) target itself declares a Secret knob in its own properties. + if m, ok := cur.(map[string]any); ok && secretRefPresent(m) { + return true + } + // (b) a sibling of target matches Secret or a generic Secret knob. + if propsAtTarget != nil && targetName != "" { + for k := range propsAtTarget { + if k == targetName { + continue + } + if k == "secretRef" || k == "existingSecret" { + return true + } + if strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "SecretRef") { + return true + } + } + } + return false +} + +// jsonPointerGet resolves a /-delimited JSON Pointer into a schema root. +// Used by hasSecretRefAt to follow $ref entries in helm values.schema.json. 
+func jsonPointerGet(root any, pointer string) any { + if pointer == "" { + return root + } + parts := strings.Split(pointer, "/") + cur := root + for _, p := range parts { + p = strings.ReplaceAll(p, "~1", "/") + p = strings.ReplaceAll(p, "~0", "~") + m, ok := cur.(map[string]any) + if !ok { + return nil + } + cur = m[p] + if cur == nil { + return nil + } + } + return cur +} + +func secretRefPresent(m map[string]any) bool { + if m == nil { + return false + } + props, ok := m["properties"].(map[string]any) + if !ok { + return false + } + for k := range props { + if k == "secretRef" || k == "existingSecret" || k == "encryptionKeySecret" { + return true + } + if strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "SecretRef") { + return true + } + } + return false +} + +func collectPkgs(roots []*packages.Package) map[string]*packages.Package { + out := map[string]*packages.Package{} + packages.Visit(roots, nil, func(p *packages.Package) { + out[p.PkgPath] = p + }) + return out +} + +func (c *checker) add(f Finding) { c.findings = append(c.findings, f) } + +// countErrs returns the number of ERROR-severity findings. +func (c *checker) countErrs() int { + n := 0 + for _, f := range c.findings { + if f.Severity == "ERROR" { + n++ + } + } + return n +} + +// countWarns returns the number of WARN-severity findings. +func (c *checker) countWarns() int { + n := 0 + for _, f := range c.findings { + if f.Severity == "WARN" { + n++ + } + } + return n +} + +// collectConsts scans all loaded packages for `const X NamedStringType = "v"` +// and indexes them by namedType key "pkgpath.TypeName". 
+func (c *checker) collectConsts() { + for _, p := range c.pkgs { + if p.Types == nil { + continue + } + scope := p.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + cnst, ok := obj.(*types.Const) + if !ok { + continue + } + named, ok := cnst.Type().(*types.Named) + if !ok { + continue + } + // Only named string types + basic, ok := named.Underlying().(*types.Basic) + if !ok || basic.Kind() != types.String { + continue + } + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + v := cnst.Val() + if v.Kind() != constant.String { + continue + } + c.enumConsts[key] = append(c.enumConsts[key], constant.StringVal(v)) + } + } + for k := range c.enumConsts { + sort.Strings(c.enumConsts[k]) + } +} + +// walkType recursively walks a struct type, verifying each json-tagged field +// has a schema counterpart at the propagated schemaPath. +func (c *checker) walkType(t types.Type, schemaPath, goPath string) { + t = deref(t) + named, _ := t.(*types.Named) + if named != nil { + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + // Treat opaque types (like schemas.EnvVar) as leaves. + if _, isOpaque := opaqueLeafTypes[key]; isOpaque { + if key == "github.com/maximhq/bifrost/core/schemas.EnvVar" { + c.envVarFields = append(c.envVarFields, envVarLocation{schemaPath, goPath}) + } + return + } + if c.visited[key+"@"+schemaPath] { + return + } + c.visited[key+"@"+schemaPath] = true + } + structType, ok := t.Underlying().(*types.Struct) + if !ok { + return + } + + schemaNode := c.resolveSchema(schemaPath) + if schemaNode == nil { + c.add(Finding{Category: "schema-path-not-found", Severity: "ERROR", Path: schemaPath, Go: goPath}) + return + } + + // Collect every property key reachable from this schema node across + // properties/allOf/oneOf/anyOf/if-then-else branches. 
+ schemaProps := c.collectProperties(schemaNode, schemaPath)
+
+ goFieldTags := map[string]*types.Var{}
+ for i := 0; i < structType.NumFields(); i++ {
+ f := structType.Field(i)
+ if !f.Exported() {
+ continue
+ }
+ tag := reflectTag(structType.Tag(i), "json")
+ if tag == "" || tag == "-" {
+ continue
+ }
+ name := strings.Split(tag, ",")[0]
+ if name == "" {
+ continue
+ }
+ // GORM relational fields (`foreignKey`, `many2many`) are populated
+ // from DB joins, not user-submitted config. Skip them from the walk so
+ // schemasync only compares user-input config against config.schema.json.
+ // Schema properties for these relations may still exist (validated at
+ // the missing-in-go layer); add the schema path to ignoreSchemaProps
+ // for deliberate exceptions (see below for `provider_configs`, etc.).
+ gormTag := reflectTag(structType.Tag(i), "gorm")
+ if strings.Contains(gormTag, "foreignKey") || strings.Contains(gormTag, "many2many") {
+ continue
+ }
+ goFieldTags[name] = f
+ }
+
+ // Go-field → schema check
+ for name, f := range goFieldTags {
+ childPath := schemaPath + "/properties/" + name
+ if _, ignored := ignoreGoFields[schemaPath+"|"+name]; ignored {
+ continue
+ }
+ if _, ignored := ignoreGoFieldNames[name]; ignored {
+ continue
+ }
+ childSchema := schemaProps[name]
+ if childSchema == nil {
+ c.add(Finding{
+ Category: "missing-in-schema",
+ Severity: "ERROR",
+ Path: schemaPath + "/properties/" + name,
+ Detail: name,
+ Go: goPath + "." 
+ f.Name(), + }) + continue + } + // Recurse into field type + c.walkField(f.Type(), childSchema, childPath, goPath+"."+f.Name()) + } + + // Schema-key → Go field check (warnings; schema may legitimately be broader) + for name := range schemaProps { + if _, ignored := ignoreSchemaProps[schemaPath+"/properties/"+name]; ignored { + continue + } + if _, ignored := ignoreGoFields[schemaPath+"|"+name]; ignored { + continue + } + if _, ok := goFieldTags[name]; !ok { + c.add(Finding{ + Category: "missing-in-go", + Severity: "WARN", + Path: schemaPath + "/properties/" + name, + Detail: name, + Go: goPath, + }) + } + } +} + +// walkField dispatches based on the field's Go type. +func (c *checker) walkField(t types.Type, schemaNode map[string]any, schemaPath, goPath string) { + t = deref(t) + + // Named type → opaque-leaf check + enum check (if string const type) + if named, ok := t.(*types.Named); ok { + key := named.Obj().Pkg().Path() + "." + named.Obj().Name() + if _, isOpaque := opaqueLeafTypes[key]; isOpaque { + if key == "github.com/maximhq/bifrost/core/schemas.EnvVar" { + c.envVarFields = append(c.envVarFields, envVarLocation{schemaPath, goPath}) + } + return // do not recurse into opaque types + } + if goVals, hasConsts := c.enumConsts[key]; hasConsts && len(goVals) > 0 { + c.checkEnum(goVals, schemaNode, schemaPath, goPath, key) + } + } + + switch u := t.Underlying().(type) { + case *types.Struct: + // Recurse into named struct (anonymous structs are inlined below) + if _, ok := t.(*types.Named); ok { + c.walkType(t, schemaPath, goPath) + } else { + // Anonymous inline struct — rare but handle by walking tags + c.walkAnonymous(u, schemaPath, goPath) + } + case *types.Slice: + elem := u.Elem() + if _, isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + itemsNode := c.resolveRef(schemaNode) + if _, ok := itemsNode["items"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/items", goPath+"[]") + } + } + case *types.Array: + elem := u.Elem() + if _, 
isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + itemsNode := c.resolveRef(schemaNode) + if _, ok := itemsNode["items"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/items", goPath+"[]") + } + } + case *types.Map: + elem := u.Elem() + if _, isStruct := deref(elem).Underlying().(*types.Struct); isStruct { + node := c.resolveRef(schemaNode) + if _, ok := node["additionalProperties"].(map[string]any); ok { + c.walkType(elem, schemaPath+"/additionalProperties", goPath+"[]") + } + // If no additionalProperties/patternProperties, silently skip — schemas + // often describe provider-keyed maps via oneOf branches. + } + case *types.Basic, *types.Interface: + // Leaf — nothing to recurse into. + } +} + +// walkAnonymous handles anonymous (inline) struct fields — rare; we treat +// them as a struct walk at the same schemaPath. +func (c *checker) walkAnonymous(st *types.Struct, schemaPath, goPath string) { + // Not common in this codebase; fall back to flat tag-check. + schemaNode := c.resolveSchema(schemaPath) + if schemaNode == nil { + return + } + props := c.collectProperties(schemaNode, schemaPath) + for i := 0; i < st.NumFields(); i++ { + f := st.Field(i) + if !f.Exported() { + continue + } + tag := reflectTag(st.Tag(i), "json") + if tag == "" || tag == "-" { + continue + } + name := strings.Split(tag, ",")[0] + if _, ok := props[name]; !ok { + c.add(Finding{ + Category: "missing-in-schema", + Severity: "ERROR", + Path: schemaPath + "/properties/" + name, + Detail: name, + Go: goPath + "." + f.Name(), + }) + } + } +} + +// checkEnum diffs Go string-const values against schema enum array. 
+func (c *checker) checkEnum(goVals []string, schemaNode map[string]any, schemaPath, goPath, typeKey string) { + node := c.resolveRef(schemaNode) + rawEnum, ok := node["enum"] + if !ok { + c.add(Finding{ + Category: "enum-no-schema", + Severity: "WARN", + Path: schemaPath, + Detail: fmt.Sprintf("%v (Go consts)", goVals), + Go: typeKey, + }) + return + } + enumArr, ok := rawEnum.([]any) + if !ok { + c.add(Finding{Category: "enum-drift", Severity: "ERROR", Path: schemaPath, Detail: "schema enum is not an array"}) + return + } + schemaSet := map[string]bool{} + for _, v := range enumArr { + if s, ok := v.(string); ok { + schemaSet[s] = true + } + } + goSet := map[string]bool{} + for _, v := range goVals { + goSet[v] = true + } + var missingInSchema, extraInSchema []string + for v := range goSet { + if !schemaSet[v] { + missingInSchema = append(missingInSchema, v) + } + } + for v := range schemaSet { + if !goSet[v] { + extraInSchema = append(extraInSchema, v) + } + } + sort.Strings(missingInSchema) + sort.Strings(extraInSchema) + if len(missingInSchema) > 0 { + c.add(Finding{ + Category: "enum-drift", + Severity: "ERROR", + Path: schemaPath, + Detail: fmt.Sprintf("schema missing Go consts %v", missingInSchema), + Go: goPath + " (" + typeKey + ")", + }) + } + if len(extraInSchema) > 0 { + c.add(Finding{ + Category: "enum-drift", + Severity: "WARN", + Path: schemaPath, + Detail: fmt.Sprintf("schema has %v with no Go const", extraInSchema), + Go: typeKey, + }) + } +} + +// collectProperties walks the schema subtree rooted at `node`, unioning +// property keys from the direct `properties`, and recursively from `allOf`, +// `oneOf`, `anyOf`, `then`, `else`. Handles $ref. +// Returns map of propertyName → subschema. 
+func (c *checker) collectProperties(node map[string]any, atPath string) map[string]map[string]any { + out := map[string]map[string]any{} + c.mergeProperties(out, node, atPath, map[string]bool{}) + return out +} + +func (c *checker) mergeProperties(out map[string]map[string]any, node map[string]any, atPath string, seen map[string]bool) { + if node == nil { + return + } + node = c.resolveRef(node) + if ref, ok := node["$ref"].(string); ok && seen[ref] { + return + } + if ref, ok := node["$ref"].(string); ok { + seen[ref] = true + } + if props, ok := node["properties"].(map[string]any); ok { + for k, v := range props { + if m, ok := v.(map[string]any); ok { + if _, already := out[k]; !already { + out[k] = m + } + } + } + } + for _, key := range []string{"allOf", "oneOf", "anyOf"} { + if arr, ok := node[key].([]any); ok { + for _, item := range arr { + if m, ok := item.(map[string]any); ok { + c.mergeProperties(out, m, atPath+"/"+key, seen) + } + } + } + } + for _, key := range []string{"then", "else"} { + if m, ok := node[key].(map[string]any); ok { + c.mergeProperties(out, m, atPath+"/"+key, seen) + } + } +} + +// resolveSchema walks a JSON-pointer path into c.schema, resolving $ref at +// each intermediate node. +func (c *checker) resolveSchema(path string) map[string]any { + parts := strings.Split(strings.TrimPrefix(path, "/"), "/") + var cur any = c.schema + for _, p := range parts { + if p == "" { + continue + } + m, ok := cur.(map[string]any) + if !ok { + return nil + } + m = c.resolveRef(m) + cur = m[unescapeJSONPointer(p)] + if cur == nil { + return nil + } + } + if m, ok := cur.(map[string]any); ok { + return c.resolveRef(m) + } + return nil +} + +// resolveRef follows a $ref pointer (recursively) to the final target node. +// $ref values are expected as "#/$defs/xxx" style JSON pointers. 
+func (c *checker) resolveRef(node map[string]any) map[string]any { + for i := 0; i < 16; i++ { + ref, ok := node["$ref"].(string) + if !ok { + return node + } + if !strings.HasPrefix(ref, "#/") { + return node // external refs unsupported + } + parts := strings.Split(strings.TrimPrefix(ref, "#/"), "/") + var cur any = c.schema + ok2 := true + for _, p := range parts { + m, isMap := cur.(map[string]any) + if !isMap { + ok2 = false + break + } + cur = m[unescapeJSONPointer(p)] + if cur == nil { + ok2 = false + break + } + } + if !ok2 { + return node + } + next, ok := cur.(map[string]any) + if !ok { + return node + } + node = next + } + return node +} + +func unescapeJSONPointer(s string) string { + s = strings.ReplaceAll(s, "~1", "/") + s = strings.ReplaceAll(s, "~0", "~") + return s +} + +// deref strips pointer wrappers to get the underlying type. +func deref(t types.Type) types.Type { + for { + p, ok := t.(*types.Pointer) + if !ok { + return t + } + t = p.Elem() + } +} + +// reflectTag parses a single struct-tag key; mirrors reflect.StructTag.Get. 
+func reflectTag(tag, key string) string { + for tag != "" { + for tag != "" && tag[0] == ' ' { + tag = tag[1:] + } + i := 0 + for i < len(tag) && tag[i] > ' ' && tag[i] != ':' && tag[i] != '"' && tag[i] != 0x7f { + i++ + } + if i == 0 || i+1 >= len(tag) || tag[i] != ':' || tag[i+1] != '"' { + break + } + name := tag[:i] + tag = tag[i+1:] + i = 1 + for i < len(tag) && tag[i] != '"' { + if tag[i] == '\\' { + i++ + } + i++ + } + if i >= len(tag) { + break + } + val := tag[1:i] + tag = tag[i+1:] + if name == key { + return val + } + } + return "" +} diff --git a/.github/workflows/scripts/setup-go-workspace.sh b/.github/workflows/scripts/setup-go-workspace.sh index a5effd3c49..29024fdaa3 100755 --- a/.github/workflows/scripts/setup-go-workspace.sh +++ b/.github/workflows/scripts/setup-go-workspace.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -euo pipefail - +export GOTOOLCHAIN=auto # If go.work exists, skip if [ -f "go.work" ]; then @@ -20,14 +20,16 @@ fi go work init go work use ./core go work use ./framework +go work use ./plugins/compat go work use ./plugins/governance go work use ./plugins/jsonparser -go work use ./plugins/litellmcompat go work use ./plugins/logging go work use ./plugins/maxim go work use ./plugins/mocker go work use ./plugins/otel +go work use ./plugins/prompts go work use ./plugins/semanticcache go work use ./plugins/telemetry go work use ./transports +go work use ./cli echo "✅ Go workspace initialized" diff --git a/.github/workflows/scripts/test-docker-image.sh b/.github/workflows/scripts/test-docker-image.sh index 5d770fbd64..ac115394bf 100755 --- a/.github/workflows/scripts/test-docker-image.sh +++ b/.github/workflows/scripts/test-docker-image.sh @@ -212,8 +212,7 @@ cat > "$CONFIG_FILE" << 'CONFIGEOF' "enable_logging": true, "enforce_governance_header": false, "allow_direct_keys": false, - "max_request_body_size_mb": 100, - "enable_litellm_fallbacks": false + "max_request_body_size_mb": 100 }, "encryption_key": "" } diff --git 
a/.github/workflows/scripts/validate-helm-config-fields.sh b/.github/workflows/scripts/validate-helm-config-fields.sh index ed7a17857d..7a78e1bf54 100755 --- a/.github/workflows/scripts/validate-helm-config-fields.sh +++ b/.github/workflows/scripts/validate-helm-config-fields.sh @@ -160,7 +160,7 @@ bifrost: enableLogging: true disableContentLogging: true disableDbPingsInHealth: true - logRetentionDays: 30 + logRetentionDays: 30 enforceGovernanceHeader: true allowDirectKeys: true maxRequestBodySizeMb: 50 @@ -168,6 +168,7 @@ bifrost: convertTextToChat: true convertChatToResponses: true shouldDropParams: true + shouldConvertParams: true prometheusLabels: - "team" - "env" @@ -206,6 +207,7 @@ assert_field_value 'client.max_request_body_size_mb' '.client.max_request_body_s assert_field_value 'client.compat.convert_text_to_chat' '.client.compat.convert_text_to_chat' 'true' assert_field_value 'client.compat.convert_chat_to_responses' '.client.compat.convert_chat_to_responses' 'true' assert_field_value 'client.compat.should_drop_params' '.client.compat.should_drop_params' 'true' +assert_field_value 'client.compat.should_convert_params' '.client.compat.should_convert_params' 'true' assert_field 'client.prometheus_labels' '.client.prometheus_labels' assert_field 'client.header_filter_config.allowlist' '.client.header_filter_config.allowlist' assert_field 'client.header_filter_config.denylist' '.client.header_filter_config.denylist' @@ -823,10 +825,10 @@ assert_field_value 'mcp client[0] tool_pricing.search' '.mcp.client_configs.[0]. assert_field_value 'mcp tool_manager_config.code_mode_binding_level' '.mcp.tool_manager_config.code_mode_binding_level' '"server"' ############################################################################### -# 8. Cluster, SAML, Load Balancer, Guardrails, Audit Logs +# 8. 
Cluster, SCIM, Load Balancer, Guardrails, Audit Logs ############################################################################### echo "" -echo -e "${CYAN}🌐 8/10 - Cluster, SAML, LB, Guardrails, Audit Logs${NC}" +echo -e "${CYAN}🌐 8/10 - Cluster, SCIM, LB, Guardrails, Audit Logs${NC}" echo "-----------------------------------------------------" cat > "$TMPDIR/values-cluster.yaml" << 'VALS' @@ -893,7 +895,8 @@ render_config "$TMPDIR/values-scim-okta.yaml" assert_field_value 'scim_config.enabled' '.scim_config.enabled' 'true' assert_field_value 'scim_config.provider' '.scim_config.provider' '"okta"' assert_field 'scim_config.config' '.scim_config.config' - +assert_field 'scim_config.config.apiToken' '.scim_config.config.apiToken' +assert_field 'scim_config.config.clientSecret' '.scim_config.config.clientSecret' # SCIM - Entra cat > "$TMPDIR/values-scim-entra.yaml" << 'VALS' image: @@ -916,6 +919,9 @@ VALS render_config "$TMPDIR/values-scim-entra.yaml" assert_field_value 'scim_config (entra) provider' '.scim_config.provider' '"entra"' assert_field 'scim_config (entra) config' '.scim_config.config' +assert_field_value 'scim_config (entra) enabled' '.scim_config.enabled' 'true' +assert_field 'scim_config (entra) config.tenantId' '.scim_config.config.tenantId' +assert_field 'scim_config (entra) config.clientId' '.scim_config.config.clientId' # Load Balancer cat > "$TMPDIR/values-lb.yaml" << 'VALS' @@ -1183,6 +1189,97 @@ assert_field_value 'logs_store.type (postgres)' '.logs_store.type' '"postgres"' assert_field_value 'logs_store.config.max_idle_conns' '.logs_store.config.max_idle_conns' '5' assert_field_value 'logs_store.config.max_open_conns' '.logs_store.config.max_open_conns' '50' +############################################################################### +# Object Storage (logsStore.objectStorage) +############################################################################### + +# S3 with inline credentials — exercises camelCase → snake_case mapping in 
_helpers.tpl +cat > "$TMPDIR/values-objstore-s3.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + prefix: "prod" + compress: true + region: "us-east-1" + endpoint: "https://minio.internal:9000" + accessKeyId: "AKIA..." + secretAccessKey: "secret" + roleArn: "arn:aws:iam::123:role/bifrost" + forcePathStyle: true +VALS + +render_config "$TMPDIR/values-objstore-s3.yaml" +assert_field_value 'logs_store.object_storage.type (s3)' '.logs_store.object_storage.type' '"s3"' +assert_field_value 'logs_store.object_storage.bucket' '.logs_store.object_storage.bucket' '"bifrost-logs"' +assert_field_value 'logs_store.object_storage.prefix' '.logs_store.object_storage.prefix' '"prod"' +assert_field_value 'logs_store.object_storage.compress' '.logs_store.object_storage.compress' 'true' +assert_field_value 'logs_store.object_storage.region' '.logs_store.object_storage.region' '"us-east-1"' +assert_field_value 'logs_store.object_storage.endpoint' '.logs_store.object_storage.endpoint' '"https://minio.internal:9000"' +assert_field_value 'logs_store.object_storage.access_key_id' '.logs_store.object_storage.access_key_id' '"AKIA..."' +assert_field_value 'logs_store.object_storage.secret_access_key' '.logs_store.object_storage.secret_access_key' '"secret"' +assert_field_value 'logs_store.object_storage.role_arn' '.logs_store.object_storage.role_arn' '"arn:aws:iam::123:role/bifrost"' +assert_field_value 'logs_store.object_storage.force_path_style' '.logs_store.object_storage.force_path_style' 'true' + +# S3 with existingSecret — exercises env.BIFROST_OBJECT_STORAGE_* substitution path +cat > "$TMPDIR/values-objstore-s3-secret.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + existingSecret: 
"bifrost-os-creds" + accessKeyIdKey: "access-key-id" + secretAccessKeyKey: "secret-access-key" + sessionTokenKey: "session-token" + roleArnKey: "role-arn" +VALS + +render_config "$TMPDIR/values-objstore-s3-secret.yaml" +assert_field_value 'logs_store.object_storage.access_key_id (env)' '.logs_store.object_storage.access_key_id' '"env.BIFROST_OBJECT_STORAGE_ACCESS_KEY_ID"' +assert_field_value 'logs_store.object_storage.secret_access_key (env)' '.logs_store.object_storage.secret_access_key' '"env.BIFROST_OBJECT_STORAGE_SECRET_ACCESS_KEY"' +assert_field_value 'logs_store.object_storage.session_token (env)' '.logs_store.object_storage.session_token' '"env.BIFROST_OBJECT_STORAGE_SESSION_TOKEN"' +assert_field_value 'logs_store.object_storage.role_arn (env)' '.logs_store.object_storage.role_arn' '"env.BIFROST_OBJECT_STORAGE_ROLE_ARN"' + +# GCS — exercises project_id + credentials_json mapping +cat > "$TMPDIR/values-objstore-gcs.yaml" << 'VALS' +image: + tag: v1.0.0 +storage: + mode: sqlite + configStore: + enabled: true + logsStore: + enabled: true + objectStorage: + enabled: true + type: gcs + bucket: "bifrost-gcs-bucket" + projectId: "my-gcp-project" + credentialsJson: "/etc/gcs/creds.json" +VALS + +render_config "$TMPDIR/values-objstore-gcs.yaml" +assert_field_value 'logs_store.object_storage.type (gcs)' '.logs_store.object_storage.type' '"gcs"' +assert_field_value 'logs_store.object_storage.bucket (gcs)' '.logs_store.object_storage.bucket' '"bifrost-gcs-bucket"' +assert_field_value 'logs_store.object_storage.project_id' '.logs_store.object_storage.project_id' '"my-gcp-project"' +assert_field_value 'logs_store.object_storage.credentials_json' '.logs_store.object_storage.credentials_json' '"/etc/gcs/creds.json"' + ############################################################################### # Summary ############################################################################### @@ -1200,4 +1297,4 @@ if [ "$TESTS_FAILED" -gt 0 ]; then else echo -e "${GREEN}✅ All 
config.json field validations passed!${NC}" exit 0 -fi +fi \ No newline at end of file diff --git a/.github/workflows/scripts/validate-helm-schema.sh b/.github/workflows/scripts/validate-helm-schema.sh index 80febb2e4c..5f012bc5c8 100755 --- a/.github/workflows/scripts/validate-helm-schema.sh +++ b/.github/workflows/scripts/validate-helm-schema.sh @@ -196,8 +196,8 @@ else echo "✅ VLLM key config required fields match: [$HELM_VLLM_REQUIRED]" fi -# Check concurrency_config required fields -CONFIG_CONCURRENCY_REQUIRED=$(jq -r '."$defs".concurrency_config.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") +# Check concurrency_and_buffer_size required fields (renamed from concurrency_config) +CONFIG_CONCURRENCY_REQUIRED=$(jq -r '."$defs".concurrency_and_buffer_size.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") HELM_CONCURRENCY_REQUIRED=$(jq -r '."$defs".concurrencyConfig.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") if [ "$CONFIG_CONCURRENCY_REQUIRED" != "$HELM_CONCURRENCY_REQUIRED" ]; then @@ -433,38 +433,17 @@ else echo "✅ MCP stdio config required fields match: [$CONFIG_MCP_STDIO_REQUIRED]" fi -# Check MCP websocket_config required fields -CONFIG_MCP_WS_REQUIRED=$(jq -r '."$defs".mcp_client_config.properties.websocket_config.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") -HELM_MCP_WS_REQUIRED=$(jq -r '."$defs".mcpClientConfig.properties.websocketConfig.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") - -if [ "$CONFIG_MCP_WS_REQUIRED" != "$HELM_MCP_WS_REQUIRED" ]; then - echo "❌ MCP websocket config required fields mismatch:" - echo " Config: [$CONFIG_MCP_WS_REQUIRED]" - echo " Helm: [$HELM_MCP_WS_REQUIRED]" - ERRORS=$((ERRORS + 1)) -else - echo "✅ MCP websocket config required fields match: [$CONFIG_MCP_WS_REQUIRED]" -fi - -# Check MCP http_config required fields -CONFIG_MCP_HTTP_REQUIRED=$(jq -r 
'."$defs".mcp_client_config.properties.http_config.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") -HELM_MCP_HTTP_REQUIRED=$(jq -r '."$defs".mcpClientConfig.properties.httpConfig.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") - -if [ "$CONFIG_MCP_HTTP_REQUIRED" != "$HELM_MCP_HTTP_REQUIRED" ]; then - echo "❌ MCP http config required fields mismatch:" - echo " Config: [$CONFIG_MCP_HTTP_REQUIRED]" - echo " Helm: [$HELM_MCP_HTTP_REQUIRED]" - ERRORS=$((ERRORS + 1)) -else - echo "✅ MCP http config required fields match: [$CONFIG_MCP_HTTP_REQUIRED]" -fi +# MCP websocket_config and http_config were removed from config.schema.json +# because the corresponding Go fields don't exist (MCP rendering uses +# connection_type + connection_string directly, not sub-object configs). +# Helm still declares them for user convenience — not a schema sync concern. echo "" echo "🔍 Checking required fields in SAML/SCIM config..." # Check okta_config required fields CONFIG_OKTA_REQUIRED=$(jq -r '."$defs".okta_config.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") -HELM_OKTA_REQUIRED=$(jq -r '.properties.bifrost.properties.saml.allOf[0].then.properties.config.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") +HELM_OKTA_REQUIRED=$(jq -r '.properties.bifrost.properties.scim.allOf[0].then.properties.config.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") if [ "$CONFIG_OKTA_REQUIRED" != "$HELM_OKTA_REQUIRED" ]; then echo "❌ Okta config required fields mismatch:" @@ -477,7 +456,7 @@ fi # Check entra_config required fields CONFIG_ENTRA_REQUIRED=$(jq -r '."$defs".entra_config.required // [] | sort | join(",")' "$CONFIG_SCHEMA" 2>/dev/null || echo "") -HELM_ENTRA_REQUIRED=$(jq -r '.properties.bifrost.properties.saml.allOf[1].then.properties.config.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") +HELM_ENTRA_REQUIRED=$(jq -r 
'.properties.bifrost.properties.scim.allOf[1].then.properties.config.required // [] | sort | join(",")' "$HELM_SCHEMA" 2>/dev/null || echo "") if [ "$CONFIG_ENTRA_REQUIRED" != "$HELM_ENTRA_REQUIRED" ]; then echo "❌ Entra config required fields mismatch:" diff --git a/.github/workflows/scripts/validate-schema-sync.sh b/.github/workflows/scripts/validate-schema-sync.sh new file mode 100755 index 0000000000..0214d0679b --- /dev/null +++ b/.github/workflows/scripts/validate-schema-sync.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Validate that Go config types in transports/bifrost-http/lib/config.go +# stay in sync (fields + enum values) with transports/config.schema.json. +# Walks the type graph recursively via go/types rather than regex-parsing source. + +if command -v readlink >/dev/null 2>&1 && readlink -f "$0" >/dev/null 2>&1; then + SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +else + SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)" +fi +REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)" +TOOL_DIR="$SCRIPT_DIR/schemasync" + +cd "$REPO_ROOT" + +if ! command -v go >/dev/null 2>&1; then + echo "❌ go toolchain required for schema-sync validation" + exit 2 +fi + +# Ensure go.work exists at the repo root. schemasync's packages.Load needs +# it to resolve bifrost's local modules against each other. On fresh CI +# runners go.work is not checked in, so we provision it here inline. +# Sibling scripts (test-bifrost-http.sh etc.) call setup-go-workspace.sh +# via `source`, but that relies on the `return` builtin which has +# platform-dependent edge cases under `set -e`; we instead do the same +# work inline so this wrapper is self-contained. +if [ ! -f "$REPO_ROOT/go.work" ]; then + echo "🔧 Setting up Go workspace (go.work not found)..." 
+ ( + cd "$REPO_ROOT" + go work init + for mod in ./core ./framework \ + ./plugins/compat ./plugins/governance ./plugins/jsonparser \ + ./plugins/logging ./plugins/maxim ./plugins/mocker \ + ./plugins/otel ./plugins/prompts ./plugins/semanticcache \ + ./plugins/telemetry \ + ./transports ./cli; do + if [ -f "$REPO_ROOT/$mod/go.mod" ]; then + go work use "$mod" + fi + done + ) + echo "✅ Go workspace initialized at $REPO_ROOT/go.work" +else + echo "🔍 Go workspace already exists at $REPO_ROOT/go.work, skipping initialization" +fi + +echo "🔍 Validating Go ↔ config.schema.json sync (recursive, AST-based)" +echo "==================================================================" + +# The schemasync tool is its own module (separate go.mod). Build it with +# GOWORK=off so the tool's deps (golang.org/x/tools) resolve against its +# own go.mod, not the repo's go.work. At runtime the tool itself sets +# GOWORK=/go.work when loading bifrost packages. +(cd "$TOOL_DIR" && GOWORK=off go build -o /tmp/schemasync .) 
+/tmp/schemasync \ + --schema "$REPO_ROOT/transports/config.schema.json" \ + --pkg-root "$REPO_ROOT" \ + --helm-values "$REPO_ROOT/helm-charts/bifrost/values.schema.json" \ + --helm-helpers "$REPO_ROOT/helm-charts/bifrost/templates/_helpers.tpl" diff --git a/.github/workflows/snyk.yml b/.github/workflows/snyk.yml index 5478ca835e..cfb3fe7b82 100644 --- a/.github/workflows/snyk.yml +++ b/.github/workflows/snyk.yml @@ -16,10 +16,29 @@ jobs: name: Snyk Open Source (deps) runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.github.com:443 + api.snyk.io:443 + downloads.snyk.io:443 + files.pythonhosted.org:443 + fonts.googleapis.com:443 + fonts.gstatic.com:443 + github.com:443 + iojs.org:443 + nodejs.org:443 + packages.microsoft.com:443 + proxy.golang.org:443 + raw.githubusercontent.com:443 + registry.npmjs.org:443 + release-assets.githubusercontent.com:443 + releases.astral.sh:443 + static.snyk.io:443 + storage.googleapis.com:443 + sum.golang.org:443 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -42,7 +61,7 @@ jobs: - name: Setup Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Setup Go workspace run: make setup-workspace @@ -70,10 +89,29 @@ jobs: name: Snyk Code (SAST) runs-on: ubuntu-latest steps: - - name: Harden the runner (Audit all outbound calls) + - name: Harden Runner uses: step-security/harden-runner@fa2e9d605c4eeb9fcad4c99c224cee0c6c7f3594 # v2.16.0 with: - egress-policy: audit + egress-policy: block + allowed-endpoints: > + api.github.com:443 + api.snyk.io:443 + deeproxy.snyk.io:443 + downloads.snyk.io:443 + files.pythonhosted.org:443 + fonts.googleapis.com:443 + fonts.gstatic.com:443 + github.com:443 + 
iojs.org:443 + nodejs.org:443 + packages.microsoft.com:443 + proxy.golang.org:443 + raw.githubusercontent.com:443 + registry.npmjs.org:443 + release-assets.githubusercontent.com:443 + releases.astral.sh:443 + storage.googleapis.com:443 + sum.golang.org:443 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 @@ -96,7 +134,7 @@ jobs: - name: Setup Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: "1.26.2" + go-version: "1.26.1" - name: Setup Go workspace run: make setup-workspace diff --git a/.gitignore b/.gitignore index 8c178eee8f..8599e1f640 100644 --- a/.gitignore +++ b/.gitignore @@ -119,4 +119,56 @@ terraform.tfstate.backup !*.tfvars.example # Bifrost benchmarking -bifrost-benchmarking \ No newline at end of file +bifrost-benchmarking + +# Tests +:memory: + +# Generated test TLS certs (created by tests/docker-compose.yml redis-certs-init) +tests/redis-certs/ + +# dependencies +ui/node_modules +ui/.pnp +.pnp.* +.yarn/* +!.yarn/patches +!.yarn/plugins +!.yarn/releases +!.yarn/versions + +# testing +ui/coverage + +# next.js +ui/.next/ +ui/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# env files (can opt-in for committing if needed) +.env* + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +# auto-generated TanStack Router route tree +ui/app/routeTree.gen.ts + +.tanstack +.next \ No newline at end of file diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 0000000000..1d9b7831ba --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22.12.0 diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 4da40ee345..0000000000 --- a/.prettierrc +++ /dev/null @@ -1,25 +0,0 @@ -{ - "root": true, - "printWidth": 140, - "singleQuote": false, - "bracketSpacing": true, - "semi": true, - "bracketSameLine": false, - "useTabs": true, - "tabWidth": 2, - "trailingComma": "all", - 
"plugins": [ - "prettier-plugin-tailwindcss" - ], - "pluginSearchDirs": [ - "./ui" - ], - "tailwindAttributes": [ - "buttonClassname" - ], - "tailwindFunctions": [ - "cn", - "classNames" - ], - "endOfLine": "lf" -} \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md index 03fdd812ee..f6356421ea 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -103,9 +103,9 @@ bifrost/ │ ├── mocker/ # Mock responses for testing │ ├── jsonparser/ # JSON extraction utilities │ ├── maxim/ # Maxim observability -│ └── litellmcompat/ # LiteLLM SDK compatibility (HTTP transport) +│ └── compat/ # LiteLLM SDK compatibility (HTTP transport) │ -├── ui/ # Next.js web interface +├── ui/ # React + vite web interface │ ├── app/workspace/ # Feature pages (20+ workspace sections) │ ├── components/ # Shared React components │ └── lib/ # Constants, utilities, types @@ -641,10 +641,266 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso ## Code Style - **Go**: `gofmt`/`goimports`. No custom linter config. -- **TypeScript/React**: Prettier. Next.js App Router. +- **TypeScript/React**: Oxfmt. TanStack Router. - **JSON tags**: `snake_case` matching provider API conventions. - **Error strings**: Lowercase, no trailing punctuation (Go convention). - **Provider types**: Prefixed with provider name in PascalCase (`AnthropicChatRequest`, `GeminiEmbeddingResponse`). - **Converter functions**: Pure — no side effects, no logging, no HTTP. - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`). - **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions. + +# Frontend Code Guidelines & Patterns + +This document defines the standards, structure, and best practices for writing frontend code in this project. 
+ +--- + +## Tech Stack + +- **React** (with Vite) +- **TypeScript** +- **@tanstack/react-router** (type-safe routing) +- **Tailwind CSS v4** +- **Radix UI** (primitives) +- **Local UI component library** (`ui/components/ui/`) built on Radix primitives + +--- + +## Folder Structure + +```text + +/ui +├── app # Routes & pages +├── components # Shared components +│ └── ui # Core design system components +├── hooks # Custom React hooks +├── lib # Utilities, helpers, shared logic +└── app/enterprise # Enterprise-specific code (via symlink) + +``` + +### Rules + +- All frontend code must live inside `/ui` +- Routes and pages → `ui/app` +- Shared/reusable components → `ui/components` +- Core UI primitives → `ui/components/ui` +- Utilities and libraries → `ui/lib` +- Custom hooks → `ui/hooks` + +--- + +## Libraries & Usage + +### Core Libraries + +- `react` → UI library +- `typescript` → Type safety +- `tailwindcss` → Styling +- `@tanstack/react-router` → Routing + +### UI & Visualization + +- `@radix-ui/react-*` → UI primitives +- `ui/components/ui/*` → Project's Radix-based component system +- `recharts` → Charts +- `monaco-editor` → Code editor + +### Utilities + +- `date-fns` → Date/time formatting +- `nuqs` → Query param state management + +### Tooling + +- `Oxfmt` → Code formatting +- `vitest` → Testing + +--- + +## Routing Convention + +For every new route: + +```text + +ui/app// +├── layout.tsx # Route definition using createFileRoute +├── page.tsx # Page content +└── views/ # Optional: route-specific components + +``` + +### Rules + +- Folder name must match route name +- Always use `createFileRoute` in `layout.tsx` +- `page.tsx` should only handle composition (not heavy logic) +- Route-specific components go inside `views/` + +--- + +## Component Guidelines + +### Reusability First + +- Always check if similar components/functions already exist +- Prefer extending or refactoring existing code over duplication +- Only create new components if reuse is not 
feasible + +--- + +### Component Placement + +- Shared → `ui/components` +- Route-specific → `views/` inside route folder + +--- + +### JSX & Rendering + +- Avoid deeply nested conditional rendering +- Break complex UI into smaller components +- Keep components readable and maintainable + +--- + +### Lists & Keys + +- Always use **stable, unique keys** +- Never use array index as key (unless unavoidable) + +--- + +## React Best Practices + +- Avoid unnecessary or unstable dependencies in hooks +- Prevent infinite loops in `useEffect` +- Keep dependency arrays accurate and minimal +- Prefer derived state over duplicated state + +--- + +## State Management + +### Priority Order + +1. Query Params (`nuqs`) → for persistent/shareable state +2. Local State → for UI-only state +3. Redux → only when truly necessary + +--- + +### Query Params (`nuqs`) + +- Use for state that should persist across refresh/navigation +- Use proper parsers like `parseAsString` or `parseAsInteger` +- Do NOT mix query param state with local/redux state +- Follow a single consistent pattern across the codebase + +--- + +### Redux + +- Use only when global/shared state is required +- Avoid unnecessary slices +- Prefer simpler alternatives when possible + +--- + +### RTK Query (`@reduxjs/toolkit/query`) + +- Use for API calls and caching +- Use **granular tags** for cache invalidation +- Avoid invalidating entire datasets unnecessarily +- Implement **optimistic updates** where applicable + +--- + +## Forms + +We use: + +- `react-hook-form` +- `zod v4` (for schema validation) + +### Rules + +- Always define a Zod schema +- Include meaningful validation messages +- Prefer **inline field errors** (not toast notifications) +- Use `refine` / `superRefine` for complex validation +- Store schemas in: `ui/lib/types/schemas.ts` + +--- + +## Tables + +- Use `@tanstack/react-table` **only for large/complex datasets** +- For simple tables → build custom lightweight components +- Prioritize performance over 
abstraction + +--- + +## ⚡ Performance Guidelines + +- Lazy load heavy or rarely-used libraries +- Avoid unnecessary re-renders +- Split large components into smaller ones +- Keep bundle size minimal + +--- + +## Dependency Rules + +- Do NOT add new dependencies unless absolutely necessary +- Always pin exact versions (no `^` or `~`) +- Prefer existing libraries in the codebase + +--- + +## TypeScript Guidelines + +- Avoid using `any` unless absolutely unavoidable +- Prefer strict typing and inference +- Define reusable types in shared locations + +--- + +## Code Quality & Formatting + +After writing code: + +```bash +cd ui && npm run format +```` + +Then verify build: + +```bash +cd ui && npm run build +``` + +* Code must pass formatting and build checks +* Follow consistent naming and structure conventions + +--- + +## Anti-Patterns to Avoid + +* Duplicate components without considering reuse +* Mixing multiple state management approaches unnecessarily +* Overusing Redux +* Using unstable hook dependencies +* Adding heavy libraries for simple use cases +* Poorly structured or deeply nested JSX + +--- + +## Summary + +* Prioritize **reusability, performance, and consistency** +* Follow **strict folder structure and routing conventions** +* Use **the right tool for the right problem** +* Keep code **simple, predictable, and maintainable** diff --git a/Makefile b/Makefile index 9a0df1863b..84f2d66b28 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,17 @@ CYAN=\033[0;36m NC=\033[0m # No Color ECHO := printf '%b\n' -.PHONY: all help dev build-ui build build-cli run run-cli install-air clean test test-cli install-ui setup-workspace work-init work-clean docs docker-image docker-run cleanup-enterprise mod-tidy test-integrations-py test-integrations-ts install-playwright run-e2e run-e2e-ui run-e2e-headed +# nvm requires bash-compatible shell semantics; /bin/sh is dash on some Linux distros. 
+SHELL := /bin/bash + +# Ensures the Node version pinned in .nvmrc is active before any npm/node call. +# nvm is a shell function, so each recipe that needs it must inline this snippet +# via `$(USE_NODE); `. +USE_NODE = NVM_SH="$${NVM_DIR:-$$HOME/.nvm}/nvm.sh"; \ + [ -s "$$NVM_SH" ] || NVM_SH="$$(brew --prefix nvm 2>/dev/null)/nvm.sh"; \ + if [ -s "$$NVM_SH" ]; then . "$$NVM_SH" >/dev/null && nvm install >/dev/null 2>&1 && nvm use >/dev/null 2>&1; fi + +.PHONY: all help dev dev-pulse build-ui build build-cli run run-cli install-air install-pulse clean test test-cli install-ui setup-workspace work-init work-clean docs docker-image docker-run cleanup-enterprise mod-tidy test-integrations-py test-integrations-ts install-playwright run-e2e run-e2e-ui run-e2e-headed all: help @@ -61,17 +71,26 @@ cleanup-enterprise: ## Clean up enterprise directories if present @$(ECHO) "$(GREEN)Enterprise cleaned up$(NC)" install-ui: cleanup-enterprise - @which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. Please install Node.js first.$(NC)" && exit 1) - @which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1) - @$(ECHO) "$(GREEN)Node.js and npm are installed$(NC)" - @cd ui && npm ci - @which next > /dev/null || ($(ECHO) "$(YELLOW)Installing nextjs...$(NC)" && npm install -g next) + @$(USE_NODE); \ + which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. Please install Node.js first.$(NC)" && exit 1); \ + which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1); \ + $(ECHO) "$(GREEN)Node.js $$(node -v) and npm $$(npm -v) are installed$(NC)"; \ + if [ ! 
-d "ui/node_modules" ] || [ "ui/package.json" -nt "ui/node_modules/.package-lock.json" ] || [ "ui/package-lock.json" -nt "ui/node_modules/.package-lock.json" ]; then \ + $(ECHO) "$(YELLOW)Dependencies changed, running npm ci...$(NC)"; \ + cd ui && npm ci; \ + else \ + $(ECHO) "$(GREEN)UI dependencies up to date, skipping install$(NC)"; \ + fi @$(ECHO) "$(GREEN)UI deps are in sync$(NC)" install-air: ## Install air for hot reloading (if not already installed) @which air > /dev/null || ($(ECHO) "$(YELLOW)Installing air for hot reloading...$(NC)" && go install github.com/air-verse/air@latest) @$(ECHO) "$(GREEN)Air is ready$(NC)" +install-pulse: ## Install pulse for hot reloading (if not already installed) + @which pulse > /dev/null || ($(ECHO) "$(YELLOW)Installing pulse for hot reloading...$(NC)" && go install github.com/Pratham-Mishra04/pulse@latest) + @$(ECHO) "$(GREEN)Pulse is ready$(NC)" + install-delve: ## Install delve for debugging (if not already installed) @which dlv > /dev/null || ($(ECHO) "$(YELLOW)Installing delve for debugging...$(NC)" && go install github.com/go-delve/delve/cmd/dlv@latest) @$(ECHO) "$(GREEN)Delve is ready$(NC)" @@ -86,6 +105,7 @@ install-junit-viewer: ## Install junit-viewer for HTML report generation (if not $(ECHO) "$(GREEN)junit-viewer is already installed$(NC)"; \ else \ $(ECHO) "$(YELLOW)Installing junit-viewer for HTML reports...$(NC)"; \ + $(USE_NODE); \ if npm install -g junit-viewer 2>&1; then \ $(ECHO) "$(GREEN)junit-viewer installed successfully$(NC)"; \ else \ @@ -114,9 +134,9 @@ dev: install-ui install-air setup-workspace $(if $(DEBUG),install-delve) ## Star fi @$(ECHO) "" @$(ECHO) "$(YELLOW)Starting UI development server...$(NC)" - @if [ -n "$(DISABLE_PROFILER)" ]; then \ + @$(USE_NODE); if [ -n "$(DISABLE_PROFILER)" ]; then \ $(ECHO) "$(CYAN)DevProfiler disabled for testing$(NC)"; \ - cd ui && NEXT_PUBLIC_DISABLE_PROFILER=1 npm run dev & \ + cd ui && BIFROST_DISABLE_PROFILER=1 npm run dev & \ else \ cd ui && npm run dev & \ 
fi @@ -147,10 +167,59 @@ dev: install-ui install-air setup-workspace $(if $(DEBUG),install-delve) ## Star $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ fi +dev-pulse: install-ui install-pulse setup-workspace $(if $(DEBUG),install-delve) ## Start complete development environment using pulse for hot reloading + @$(ECHO) "$(GREEN)Starting Bifrost complete development environment (pulse)...$(NC)" + @$(ECHO) "$(YELLOW)This will start:$(NC)" + @$(ECHO) " 1. UI development server (localhost:3000)" + @$(ECHO) " 2. API server with UI proxy (localhost:$(PORT))" + @$(ECHO) "$(CYAN)Access everything at: http://localhost:$(PORT)$(NC)" + @if [ -n "$(DEBUG)" ]; then \ + $(ECHO) "$(CYAN) 3. Debugger (delve) listening on port 2345$(NC)"; \ + fi + @if [ ! -d "transports/bifrost-http/ui" ]; then \ + $(ECHO) "$(YELLOW)Creating transports/bifrost-http/ui directory...$(NC)"; \ + mkdir -p transports/bifrost-http/ui; \ + touch transports/bifrost-http/ui/.tmp; \ + fi + @$(ECHO) "" + @$(ECHO) "$(YELLOW)Starting UI development server...$(NC)" + @$(USE_NODE); if [ -n "$(DISABLE_PROFILER)" ]; then \ + $(ECHO) "$(CYAN)DevProfiler disabled for testing$(NC)"; \ + cd ui && BIFROST_DISABLE_PROFILER=1 npm run dev & \ + else \ + cd ui && npm run dev & \ + fi + @sleep 3 + @$(ECHO) "$(YELLOW)Starting API server with UI proxy...$(NC)" + @$(MAKE) setup-workspace >/dev/null + @if [ -f .env ]; then \ + $(ECHO) "$(YELLOW)Loading environment variables from .env...$(NC)"; \ + set -a; . 
./.env; set +a; \ + fi; \ + if [ -n "$(DEBUG)" ]; then \ + $(ECHO) "$(CYAN)Starting with pulse + delve debugger on port 2345...$(NC)"; \ + $(ECHO) "$(YELLOW)Attach your debugger to localhost:2345$(NC)"; \ + BIFROST_UI_DEV=true pulse -- \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") \ + $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ + else \ + BIFROST_UI_DEV=true pulse -- \ + -host "$(HOST)" \ + -port "$(PORT)" \ + -log-style "$(LOG_STYLE)" \ + -log-level "$(LOG_LEVEL)" \ + $(if $(PROMETHEUS_LABELS),-prometheus-labels "$(PROMETHEUS_LABELS)") \ + $(if $(APP_DIR),-app-dir "$(APP_DIR)"); \ + fi + build-ui: install-ui ## Build ui @$(ECHO) "$(GREEN)Building ui...$(NC)" @rm -rf ui/.next - @cd ui && npm run build && npm run copy-build + @$(USE_NODE); cd ui && npm run build && npm run copy-build build: build-ui ## Build bifrost-http binary @if [ -n "$(LOCAL)" ]; then \ @@ -828,7 +897,8 @@ test-governance: install-gotestsum $(if $(DEBUG),install-delve) ## Run governanc setup-mcp-tests: ## Build all MCP test servers in examples/mcps/ (Go and TypeScript) @$(ECHO) "$(GREEN)Building MCP test servers...$(NC)" - @FAILED=0; \ + @$(USE_NODE); \ + FAILED=0; \ for mcp_dir in examples/mcps/*/; do \ if [ -d "$$mcp_dir" ]; then \ mcp_name=$$(basename $$mcp_dir); \ @@ -1200,6 +1270,7 @@ test-integrations-ts: ## Run TypeScript integration tests (Usage: make test-inte done; \ fi; \ TEST_FAILED=0; \ + $(USE_NODE); \ if ! which npm > /dev/null 2>&1; then \ $(ECHO) "$(RED)Error: npm not found$(NC)"; \ $(ECHO) "$(YELLOW)Install Node.js: https://nodejs.org/$(NC)"; \ @@ -1258,7 +1329,7 @@ install-playwright: ## Install Playwright test dependencies @$(ECHO) "$(GREEN)Installing Playwright dependencies...$(NC)" @which node > /dev/null || ($(ECHO) "$(RED)Error: Node.js is not installed. 
Please install Node.js first.$(NC)" && exit 1) @which npm > /dev/null || ($(ECHO) "$(RED)Error: npm is not installed. Please install npm first.$(NC)" && exit 1) - @cd tests/e2e && npm ci + @$(USE_NODE); cd tests/e2e && npm ci @cd tests/e2e && if npx playwright install --list 2>/dev/null | grep -q "chromium"; then \ $(ECHO) "$(CYAN)Chromium is already installed, skipping download$(NC)"; \ else \ diff --git a/README.md b/README.md index 9e843c58fc..f957c58ee1 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/maximhq/bifrost/core)](https://goreportcard.com/report/github.com/maximhq/bifrost/core) [![Discord badge](https://dcbadge.limes.pink/api/server/https://discord.gg/exN5KAydbU?style=flat)](https://discord.gg/exN5KAydbU) -[![Known Vulnerabilities](https://snyk.io/test/github/maximhq/bifrost/badge.svg)](https://snyk.io/test/github/maximhq/bifrost) [![codecov](https://codecov.io/gh/maximhq/bifrost/branch/main/graph/badge.svg)](https://codecov.io/gh/maximhq/bifrost) ![Docker Pulls](https://img.shields.io/docker/pulls/maximhq/bifrost) [Run In Postman](https://app.getpostman.com/run-collection/31642484-2ba0e658-4dcd-49f4-845a-0c7ed745b916?action=collection%2Ffork&source=rip_markdown&collection-url=entityId%3D31642484-2ba0e658-4dcd-49f4-845a-0c7ed745b916%26entityType%3Dcollection%26workspaceId%3D63e853c8-9aec-477f-909c-7f02f543150e) diff --git a/cli/go.mod b/cli/go.mod index c887aac505..f7260a3b86 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -1,6 +1,6 @@ module github.com/maximhq/bifrost/cli -go 1.26.2 +go 1.26.1 require ( github.com/bytedance/sonic v1.15.0 diff --git a/core/bifrost.go b/core/bifrost.go index 7db6663758..83ec4777af 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -5,8 +5,8 @@ package bifrost import ( "context" + "errors" "fmt" - "math/rand" "slices" "sort" "strings" @@ -17,6 +17,7 @@ import ( "github.com/bytedance/sonic" "github.com/google/uuid" + 
"github.com/maximhq/bifrost/core/keyselectors" "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/providers/anthropic" @@ -89,12 +90,41 @@ type Bifrost struct { // ProviderQueue wraps a provider's request channel with lifecycle management // to prevent "send on closed channel" panics during provider removal/update. // Producers must check the closing flag or select on the done channel before sending. +// +// Why pq.queue is NEVER closed: +// +// Closing a channel in Go causes any concurrent send to that channel to panic +// ("send on closed channel"). There is always a TOCTOU window between a +// producer's isClosing() check and its select { case pq.queue <- msg: ... }: +// the producer could pass isClosing() while the queue is open, get preempted, +// and resume only after the queue is closed. Go's selectgo evaluates select +// cases in a random order, so even having case <-pq.done: in the same select +// does not protect against this — if selectgo evaluates the send case first on +// a closed channel it panics immediately via goto sclose, before reaching done. +// +// To close pq.queue safely you would need a sender-side WaitGroup so that +// signalClosing could wait for every in-flight producer to finish. That adds +// non-trivial overhead on the hot request path. +// +// Instead, pq.done is the sole shutdown signal. Receiving from a closed channel +// is always safe (returns the zero value immediately), so: +// - Workers exit via case <-pq.done: — safe +// - Producers bail via case <-pq.done: — safe +// - drainQueueWithErrors handles any messages that slip through the TOCTOU window +// +// pq.queue is garbage collected automatically: +// - RemoveProvider calls requestQueues.Delete, dropping the map's reference. +// - UpdateProvider calls requestQueues.Store with a new queue, dropping the +// map's reference to oldPq. 
Shutdown does not Delete at all — the whole +// Bifrost instance is torn down. +// In all cases, once no producer goroutine holds a reference to the +// ProviderQueue, both the struct and pq.queue are eligible for GC. +// No explicit close is needed. type ProviderQueue struct { - queue chan *ChannelMessage // the actual request queue channel - done chan struct{} // closed to signal shutdown to producers + queue chan *ChannelMessage // the actual request queue channel — never closed, see above + done chan struct{} // closed by signalClosing() to signal shutdown; never written to otherwise closing uint32 // atomic: 0 = open, 1 = closing signalOnce sync.Once - closeOnce sync.Once } func isLargePayloadPassthrough(ctx *schemas.BifrostContext) bool { @@ -122,14 +152,6 @@ func (pq *ProviderQueue) signalClosing() { }) } -// closeQueue closes the provider queue. -// Protected by sync.Once to prevent double-close. -func (pq *ProviderQueue) closeQueue() { - pq.closeOnce.Do(func() { - close(pq.queue) - }) -} - // isClosing returns true if the provider queue is closing. // Uses atomic load for lock-free checking. 
func (pq *ProviderQueue) isClosing() bool { @@ -153,6 +175,9 @@ type PluginPipeline struct { postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name postHookPluginOrder []string // order in which post-hooks ran (for nested span creation) chunkCount int + + // Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks) + streamScopedCtxs map[string]*schemas.BifrostContext } // pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks @@ -221,7 +246,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { bifrost.dropExcessRequests.Store(config.DropExcessRequests) if bifrost.keySelector == nil { - bifrost.keySelector = WeightedRandomKeySelector + bifrost.keySelector = keyselectors.WeightedRandom } // Initialize object pools @@ -592,9 +617,10 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * Message: "prompt not provided for text completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -607,7 +633,7 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req * if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TextCompletionResponse, nil } @@ -631,9 +657,10 @@ func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "text not provided for text completion stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TextCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + 
ResolvedModelUsed: req.Model, }, } } @@ -662,9 +689,10 @@ func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, r Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -727,9 +755,10 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -763,9 +792,10 @@ func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *s Message: "responses not provided for responses request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -831,9 +861,10 @@ func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req Message: "responses not provided for responses stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ResponsesStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -866,9 +897,10 @@ func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *sch Message: "input 
not provided for count tokens request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.CountTokensRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.CountTokensRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -898,16 +930,19 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem }, } } - if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !isLargePayloadPassthrough(ctx) { + hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil && + (req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil) + if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "embedding input not provided for embedding request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.EmbeddingRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.EmbeddingRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -920,7 +955,7 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schem if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.EmbeddingResponse, nil } @@ -944,9 +979,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. 
Message: "query not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -957,9 +993,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: "documents not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -971,9 +1008,10 @@ func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas. Message: fmt.Sprintf("document text is empty at index %d", i), }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.RerankRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.RerankRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1009,9 +1047,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document type not provided for ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1022,9 +1061,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "document_url not provided for document_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: 
schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1035,9 +1075,10 @@ func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.Bif Message: "image_url not provided for image_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.OCRRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.OCRRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1072,9 +1113,10 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. Message: "speech input not provided for speech request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1087,7 +1129,7 @@ func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas. 
if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.SpeechResponse, nil } @@ -1111,9 +1153,10 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *sc Message: "speech input not provided for speech stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.SpeechStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1145,9 +1188,10 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s Message: "transcription input not provided for transcription request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1160,7 +1204,7 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *s if err != nil { return nil, err } - //TODO: Release the response + // TODO: Release the response return response.TranscriptionResponse, nil } @@ -1184,9 +1228,10 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, Message: "transcription input not provided for transcription stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.TranscriptionStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1200,7 +1245,8 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, // ImageGenerationRequest sends an image generation request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1219,9 +1265,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for image generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1241,9 +1288,10 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1253,7 +1301,8 @@ func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, // ImageGenerationStreamRequest sends an image generation stream request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + req *schemas.BifrostImageGenerationRequest, +) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1272,9 +1321,10 @@ func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext Message: "prompt not provided for image generation stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageGenerationStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1306,14 +1356,19 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "images not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditParamsType *string + if req.Params != nil { + imageEditParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1321,9 +1376,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "prompt not provided for image 
edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1344,9 +1400,10 @@ func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schem Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1374,14 +1431,19 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "images not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } - // Prompt is not required when type is background_removal - if (req.Params == nil || req.Params.Type == nil || *req.Params.Type != "background_removal") && + // Prompt is not required for certain operation types that work without a text prompt + var imageEditStreamParamsType *string + if req.Params != nil { + imageEditStreamParamsType = req.Params.Type + } + if !isPromptOptionalImageEditType(imageEditStreamParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1389,9 +1451,10 @@ func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req Message: "prompt not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - 
RequestType: schemas.ImageEditStreamRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageEditStreamRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1423,9 +1486,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "image not provided for image variation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1446,9 +1510,10 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ImageVariationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.ImageVariationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1458,7 +1523,8 @@ func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req * // VideoGenerationRequest sends a video generation request to the specified provider. 
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, - req *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { + req *schemas.BifrostVideoGenerationRequest, +) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1477,9 +1543,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "prompt not provided for video generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -1499,9 +1566,10 @@ func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.VideoGenerationRequest, - Provider: req.Provider, - ModelRequested: req.Model, + RequestType: schemas.VideoGenerationRequest, + Provider: req.Provider, + OriginalModelRequested: req.Model, + ResolvedModelUsed: req.Model, }, } } @@ -3109,57 +3177,36 @@ func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error } pq := pqValue.(*ProviderQueue) - // Step 2: Signal closing to producers (prevents new sends) - // This must happen before closing the queue to avoid "send on closed channel" panics + // Step 2: Signal closing. Blocks new producers (isClosing() returns true) and + // causes idle workers to drain remaining buffered requests with errors then exit. 
pq.signalClosing() bifrost.logger.Debug("signaled closing for provider %s", providerKey) - // Step 3: Now safe to close the queue (no new producers can send) - pq.closeQueue() - bifrost.logger.Debug("closed request queue for provider %s", providerKey) - - // Step 4: Wait for all workers to finish processing in-flight requests + // Step 3: Wait for all workers to finish in-flight requests and exit. waitGroup, exists := bifrost.waitGroups.Load(providerKey) if exists { waitGroup.(*sync.WaitGroup).Wait() bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) } - // Step 5: Remove the provider from the request queues + // Step 3b: Final drain sweep — see drainQueueWithErrors for full explanation. + bifrost.drainQueueWithErrors(pq) + + // Step 4: Remove the provider from the request queues. bifrost.requestQueues.Delete(providerKey) - // Step 6: Remove the provider from the wait groups + // Step 5: Remove the provider from the wait groups. bifrost.waitGroups.Delete(providerKey) - // Step 7: Remove the provider from the providers slice - replacementAttempts := 0 - maxReplacementAttempts := 100 // Prevent infinite loops in high-contention scenarios - for { - replacementAttempts++ - if replacementAttempts > maxReplacementAttempts { - return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts", providerKey, maxReplacementAttempts) - } - oldPtr := bifrost.providers.Load() - var oldSlice []schemas.Provider - if oldPtr != nil { - oldSlice = *oldPtr - } - // Create new slice without the old provider of this key - // Use exact capacity to avoid allocations - if len(oldSlice) == 0 { - return fmt.Errorf("provider %s not found in providers slice", providerKey) - } - newSlice := make([]schemas.Provider, 0, len(oldSlice)-1) - for _, existingProvider := range oldSlice { - if existingProvider.GetProviderKey() != providerKey { - newSlice = append(newSlice, existingProvider) - } - } - if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { 
- bifrost.logger.Debug("successfully removed provider instance for %s in providers slice", providerKey) - break - } - // Retrying as swapping did not work (likely due to concurrent modification) + // Step 6: Remove the provider from the providers slice. + if err := bifrost.removeProviderFromSlice(providerKey); err != nil { + bifrost.logger.Error( + "provider %s was removed from queues but could not be removed from the providers slice — "+ + "bifrost.providers is now inconsistent. "+ + "To recover: retry RemoveProvider(%s), or restart Bifrost if that fails.", + providerKey, providerKey, + ) + return err } bifrost.logger.Info("successfully removed provider %s", providerKey) @@ -3181,6 +3228,15 @@ func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error // Note: This operation will temporarily pause request processing for the specified provider // while the transition occurs. In-flight requests will complete before workers are stopped. // Buffered requests in the old queue will be transferred to the new queue to prevent loss. +// +// Concurrency safety — no-worker window: +// UpdateProvider holds a per-provider write lock (providerMutex.Lock) for its entire +// duration. All producer paths (tryRequest, tryStreamRequest) acquire the corresponding +// read lock inside getProviderQueue before they can look up or enqueue into any queue. +// This means no producer can observe or enqueue into newPq until UpdateProvider returns +// and releases the write lock — at which point new workers are already running and +// consuming newPq. There is therefore no window where newPq is visible to producers +// but has zero workers. 
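The RemoveProvider sequence in this hunk (signal closing → wait for workers → final drain sweep → delete from maps) can be sketched in isolation. This is a simplified model for illustration only; `queue`, `drainWithErrors`, and the worker body are stand-ins, not Bifrost's actual `ProviderQueue` types:

```go
package main

import (
	"fmt"
	"sync"
)

// queue is a simplified stand-in for a provider request queue.
type queue struct {
	ch         chan int
	done       chan struct{}
	signalOnce sync.Once
}

// signalClosing flips the queue into closing state exactly once. Producers
// check isClosing before enqueueing; workers select on done to exit.
func (q *queue) signalClosing() { q.signalOnce.Do(func() { close(q.done) }) }

func (q *queue) isClosing() bool {
	select {
	case <-q.done:
		return true
	default:
		return false
	}
}

// drainWithErrors models the "final drain sweep": any request still buffered
// after the workers exit is failed fast rather than left hanging forever.
func drainWithErrors(q *queue, fail func(int)) {
	for {
		select {
		case msg := <-q.ch:
			fail(msg)
		default:
			return // buffer empty
		}
	}
}

func main() {
	q := &queue{ch: make(chan int, 8), done: make(chan struct{})}
	var workers sync.WaitGroup
	workers.Add(1)
	// Worker that exits when closing is signalled (request processing elided).
	go func() { defer workers.Done(); <-q.done }()

	q.ch <- 1
	q.ch <- 2
	q.signalClosing() // step 2: block new producers, signal workers to exit
	workers.Wait()    // step 3: wait for workers to finish
	failed := 0
	drainWithErrors(q, func(int) { failed++ }) // step 3b: final sweep
	fmt.Println(q.isClosing(), failed)         // prints true 2
}
```

Note the ordering: the sweep runs only after `workers.Wait()`, so nothing races the drain for buffered messages.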
func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error { bifrost.logger.Info(fmt.Sprintf("Updating provider configuration for provider %s", providerKey)) // Get the updated configuration from the account @@ -3213,23 +3269,23 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error queue: make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize), done: make(chan struct{}), signalOnce: sync.Once{}, - closeOnce: sync.Once{}, } - // Step 2: Atomically replace the queue FIRST (new producers immediately get the new queue) - // This minimizes the window where requests fail during the update + // Step 2: Atomically replace the queue so new producers immediately use newPq. bifrost.requestQueues.Store(providerKey, newPq) bifrost.logger.Debug("stored new queue for provider %s, new producers will use it", providerKey) - // Step 3: Signal old queue is closing to producers that already have a reference - // Only in-flight producers with the old reference will see this - oldPq.signalClosing() - bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey) - - // Step 4: Transfer any buffered requests from old queue to new queue - // This prevents request loss during the transition + // Step 3: Transfer buffered requests from the old queue to the new queue BEFORE + // signalling workers to stop. This ensures buffered requests are processed by the + // new workers rather than being drained with errors. + // Old workers are still running and may consume some items concurrently — that is + // fine, they process them normally. + // If newPq is full during transfer, all remaining buffered requests are cancelled + // immediately rather than blocking — this avoids the deadlock where transfer goroutines + // wait for space that only opens once new workers start (which can't happen until + // the transfer completes). 
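The "no-worker window" comment above rests on a lock discipline: the updater holds the per-provider write lock across both the queue swap and worker startup, while every producer lookup takes the read lock first. A minimal sketch of that guarantee, with illustrative names (`registry`, `lookup`, `update`) rather than Bifrost's real ones:

```go
package main

import (
	"fmt"
	"sync"
)

// registry models the per-provider lock discipline: one RWMutex guards the
// queue pointer that producers look up.
type registry struct {
	mu    sync.RWMutex
	queue chan int
}

// lookup is the producer path: read-lock, fetch the current queue, enqueue.
// It blocks while an update holds the write lock.
func (r *registry) lookup() chan int {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.queue
}

// update is the UpdateProvider path: the write lock is held across BOTH the
// queue swap and worker startup, so no producer can observe the new queue
// before its workers are already running.
func (r *registry) update(newQ chan int, startWorkers func(chan int)) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.queue = newQ
	startWorkers(newQ) // runs before the lock is released
}

func main() {
	r := &registry{queue: make(chan int, 1)}
	started := false
	r.update(make(chan int, 1), func(chan int) { started = true })
	q := r.lookup() // any lookup after update sees workers already started
	q <- 42
	fmt.Println(started, <-q) // prints true 42
}
```

If `startWorkers` ran after `Unlock` instead, a producer could enqueue into a queue nobody is consuming; keeping it inside the critical section is what closes the window.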
transferredCount := 0 - var transferWaitGroup sync.WaitGroup + cancelledCount := 0 for { select { case msg := <-oldPq.queue: @@ -3237,37 +3293,33 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error case newPq.queue <- msg: transferredCount++ default: - // New queue is full, handle this request in a goroutine - // This is unlikely with proper buffer sizing but provides safety - transferWaitGroup.Add(1) - go func(m *ChannelMessage) { - defer transferWaitGroup.Done() + // newPq is full — cancel this message and all remaining in oldPq. + cancelMsg := func(r *ChannelMessage) { + prov, mod, _ := r.BifrostRequest.GetRequestFields() select { - case newPq.queue <- m: - // Message successfully transferred - case <-time.After(5 * time.Second): - bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout") - // Send error response to avoid hanging the client - provider, model, _ := m.BifrostRequest.GetRequestFields() - select { - case m.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: "request failed during provider concurrency update", - }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: m.RequestType, - Provider: provider, - ModelRequested: model, - }, - }: - case <-time.After(1 * time.Second): - // If we can't send the error either, just log and continue - bifrost.logger.Warn("Failed to send error response during transfer timeout") - } + case r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{Message: "request failed during provider concurrency update: queue full"}, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: prov, + OriginalModelRequested: mod, + }, + }: + case <-r.Context.Done(): + } + } + cancelMsg(msg) + cancelledCount++ + for { + select { + case r := <-oldPq.queue: + cancelMsg(r) + cancelledCount++ + default: + goto transferComplete } - }(msg) - goto transferComplete + } } 
default: // No more buffered messages @@ -3276,33 +3328,59 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error } transferComplete: - // Wait for all transfer goroutines to complete - transferWaitGroup.Wait() if transferredCount > 0 { bifrost.logger.Info("transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey) } + if cancelledCount > 0 { + bifrost.logger.Warn("cancelled %d buffered requests during transfer for provider %s: new queue was full", cancelledCount, providerKey) + } - // Step 5: Close the old queue to signal workers to stop - oldPq.closeQueue() - bifrost.logger.Debug("closed old request queue for provider %s", providerKey) + // Step 4: Signal the old queue is closing. Producers that still hold a reference to + // oldPq will detect this via isClosing() and transparently re-route to newPq. + // This happens after the transfer so the new queue is already populated before + // stale producers attempt their re-route. + oldPq.signalClosing() + bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey) - // Step 6: Wait for all existing workers to finish processing in-flight requests + // Step 5: Wait for all existing workers to finish processing in-flight requests. + // Workers exit via oldPq.done (signalled above). waitGroup, exists := bifrost.waitGroups.Load(providerKey) if exists { waitGroup.(*sync.WaitGroup).Wait() bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) } - // Step 7: Create new wait group for the updated workers + // Step 5b: Final drain sweep — see drainQueueWithErrors for full explanation. + bifrost.drainQueueWithErrors(oldPq) + + // Step 6: Create new wait group for the updated workers. bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{}) - // Step 8: Create provider instance + // Step 7: Create provider instance. 
provider, err := bifrost.createBaseProvider(providerKey, providerConfig) if err != nil { - return fmt.Errorf("failed to create provider instance for %s: %v", providerKey, err) - } - - // Step 8.5: Atomically replace the provider in the providers slice + // Roll back: signal closing, remove from map, then drain. + // Order matters: Delete before drainQueueWithErrors so that producers + // re-routing via requestQueues.Load find nothing and return "provider + // shutting down" immediately, narrowing the TOCTOU window before the sweep. + newPq.signalClosing() + bifrost.requestQueues.Delete(providerKey) + bifrost.waitGroups.Delete(providerKey) + bifrost.drainQueueWithErrors(newPq) + if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil { + bifrost.logger.Error( + "UpdateProvider rollback for %s is incomplete — provider was removed from queues "+ + "but could not be removed from the providers slice: %v. "+ + "bifrost.providers is now inconsistent. "+ + "To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+ + "or restart Bifrost if that fails.", + providerKey, sliceErr, providerKey, + ) + } + return fmt.Errorf("provider update for %s failed during initialization; provider has been removed — re-add or retry UpdateProvider to restore it: %v", providerKey, err) + } + + // Step 8: Atomically replace the provider in the providers slice. 
// This must happen before starting new workers to prevent stale reads bifrost.logger.Debug("atomically replacing provider instance in providers slice for %s", providerKey) @@ -3312,7 +3390,21 @@ transferComplete: for { replacementAttempts++ if replacementAttempts > maxReplacementAttempts { - return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts", providerKey, maxReplacementAttempts) + newPq.signalClosing() + bifrost.requestQueues.Delete(providerKey) + bifrost.waitGroups.Delete(providerKey) + bifrost.drainQueueWithErrors(newPq) + if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil { + bifrost.logger.Error( + "UpdateProvider rollback for %s is incomplete — provider was removed from queues "+ + "but could not be removed from the providers slice: %v. "+ + "bifrost.providers is now inconsistent. "+ + "To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+ + "or restart Bifrost if that fails.", + providerKey, sliceErr, providerKey, + ) + } + return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts; provider has been removed — re-add or retry UpdateProvider to restore it", providerKey, maxReplacementAttempts) } oldPtr := bifrost.providers.Load() @@ -3348,7 +3440,7 @@ transferComplete: // Retrying as swapping did not work (likely due to concurrent modification) } - // Step 9: Start new workers with updated concurrency + // Step 9: Start new workers with updated concurrency. bifrost.logger.Debug("starting %d new workers for provider %s with buffer size %d", providerConfig.ConcurrencyAndBufferSize.Concurrency, providerKey, @@ -3384,6 +3476,33 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn return mutexValue.(*sync.RWMutex) } +// removeProviderFromSlice atomically removes the provider with the given key +// from bifrost.providers using a CAS retry loop. 
Callers hold the per-provider +// write mutex so no concurrent goroutine can re-add this key — contention is +// only from other providers' CAS operations, so the loop converges in at most +// a few iterations under any concurrency level. +// Returns an error if the limit is hit (state will be inconsistent). +func (bifrost *Bifrost) removeProviderFromSlice(providerKey schemas.ModelProvider) error { + const maxAttempts = 100 + for range maxAttempts { + oldPtr := bifrost.providers.Load() + if oldPtr == nil { + return nil + } + oldSlice := *oldPtr + newSlice := make([]schemas.Provider, 0, len(oldSlice)) + for _, p := range oldSlice { + if p.GetProviderKey() != providerKey { + newSlice = append(newSlice, p) + } + } + if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { + return nil + } + } + return fmt.Errorf("failed to remove provider %s from providers slice after %d attempts", providerKey, maxAttempts) +} + // MCP PUBLIC API // RegisterMCPTool registers a typed tool handler with the MCP integration. 
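The new `removeProviderFromSlice` helper uses a copy-on-write CAS loop over an atomic slice pointer. A standalone sketch of the same pattern, with a toy `provider` type standing in for `schemas.Provider`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// provider is an illustrative stand-in for schemas.Provider.
type provider struct{ key string }

// providers holds an immutable snapshot; writers never mutate the slice in
// place, they build a new one and swap the pointer.
var providers atomic.Pointer[[]provider]

// removeProvider removes every provider with the given key using a
// compare-and-swap retry loop, mirroring removeProviderFromSlice.
func removeProvider(key string) bool {
	const maxAttempts = 100
	for i := 0; i < maxAttempts; i++ {
		oldPtr := providers.Load()
		if oldPtr == nil {
			return true // nothing registered, nothing to remove
		}
		oldSlice := *oldPtr
		newSlice := make([]provider, 0, len(oldSlice))
		for _, p := range oldSlice {
			if p.key != key {
				newSlice = append(newSlice, p)
			}
		}
		// CAS fails only if another goroutine swapped the pointer between our
		// Load and here; retry against the fresh snapshot.
		if providers.CompareAndSwap(oldPtr, &newSlice) {
			return true
		}
	}
	return false // pathological contention; caller reports inconsistency
}

func main() {
	initial := []provider{{"openai"}, {"anthropic"}, {"vertex"}}
	providers.Store(&initial)
	removeProvider("anthropic")
	fmt.Println(len(*providers.Load())) // prints 2
}
```

Because the caller holds the per-provider write mutex, no concurrent goroutine re-adds the same key mid-loop; contention comes only from other providers' swaps, so the loop converges quickly in practice.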
@@ -3411,7 +3530,7 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RegisterTool(name, description, handler, toolSchema) @@ -3429,7 +3548,7 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { if bifrost.MCPManager == nil { - return nil, fmt.Errorf("MCP is not configured in this Bifrost instance") + return nil, fmt.Errorf("mcp is not configured in this bifrost instance") } clients := bifrost.MCPManager.GetClients() @@ -3469,7 +3588,7 @@ func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { // // Returns: // - []schemas.ChatTool: List of available tools -func (bifrost *Bifrost) GetAvailableMCPTools(ctx context.Context) []schemas.ChatTool { +func (bifrost *Bifrost) GetAvailableMCPTools(ctx *schemas.BifrostContext) []schemas.ChatTool { if bifrost.MCPManager == nil { return nil } @@ -3540,7 +3659,7 @@ func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RemoveClient(id) @@ -3548,11 +3667,31 @@ func (bifrost *Bifrost) RemoveMCPClient(id string) error { // SetMCPManager sets the MCP manager for this Bifrost instance. // This allows injecting a custom MCP manager implementation (e.g., for enterprise features). 
+// If the provided manager is a concrete *mcp.MCPManager, Bifrost's plugin pipeline is injected +// into the manager's CodeMode so that nested tool calls run through the plugin hooks. // // Parameters: // - manager: The MCP manager to set (must implement MCPManagerInterface) func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { bifrost.MCPManager = manager + // Inject Bifrost's plugin pipeline into the manager's CodeMode so that + // nested tool calls (e.g. via Starlark executeCode) run through plugin hooks. + if m, ok := manager.(*mcp.MCPManager); ok { + m.SetPluginPipeline( + func() mcp.PluginPipeline { + pipeline := bifrost.getPluginPipeline() + if pp, ok := any(pipeline).(mcp.PluginPipeline); ok { + return pp + } + return nil + }, + func(pipeline mcp.PluginPipeline) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + }, + ) + } } // UpdateMCPClient updates the MCP client. @@ -3573,7 +3712,7 @@ func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { // }) func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.UpdateClient(id, updatedConfig) @@ -3588,23 +3727,63 @@ func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPCli // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.ReconnectClient(id) } +// VerifyPerUserOAuthConnection delegates to the MCP manager to verify an MCP +// server using a temporary access token and discover available tools. 
The +// connection is closed after verification. If the MCP manager is not yet +// initialized, it is lazily created (same as AddMCPClient). +func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + // Ensure MCP manager is initialized (lazy init, same pattern as AddMCPClient) + if bifrost.MCPManager == nil { + bifrost.mcpInitOnce.Do(func() { + mcpConfig := schemas.MCPConfig{ + ClientConfigs: []*schemas.MCPClientConfig{}, + } + mcpConfig.PluginPipelineProvider = func() interface{} { + return bifrost.getPluginPipeline() + } + mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { + if pp, ok := pipeline.(*PluginPipeline); ok { + bifrost.releasePluginPipeline(pp) + } + } + codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger) + bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) + }) + } + if bifrost.MCPManager == nil { + return nil, nil, fmt.Errorf("MCP manager is not initialized") + } + return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken) +} + +// SetClientTools delegates to the MCP manager to update the tool map for an +// existing MCP client. +func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + if bifrost.MCPManager != nil { + bifrost.MCPManager.SetClientTools(clientID, tools, toolNameMapping) + } +} + // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. -func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string) error { +// Pass the current value of disableAutoToolInject whenever only other fields +// change so the flag is never silently reset to its zero value. 
+func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error { if bifrost.MCPManager == nil { - return fmt.Errorf("MCP is not configured in this Bifrost instance") + return fmt.Errorf("mcp is not configured in this bifrost instance") } bifrost.MCPManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ - MaxAgentDepth: maxAgentDepth, - ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, - CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + MaxAgentDepth: maxAgentDepth, + ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, + CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), + DisableAutoToolInject: disableAutoToolInject, }) return nil } @@ -3694,7 +3873,6 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi queue: make(chan *ChannelMessage, config.ConcurrencyAndBufferSize.BufferSize), done: make(chan struct{}), signalOnce: sync.Once{}, - closeOnce: sync.Once{}, } bifrost.requestQueues.Store(providerKey, pq) @@ -3787,9 +3965,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche return bifrost.getProviderByKey(providerKey) } -// SelectKeyForProvider selects an API key for the given provider and model. -// Used by WebSocket handlers that need a key for upstream connections. -func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. +// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific +// AllowedRequests gates such as realtime-only support. 
+func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } @@ -3798,7 +3977,17 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } - return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider) + supportedKeys, _, err := bifrost.selectKeyFromProviderForModelWithPool(ctx, requestType, providerKey, model, baseProvider) + if err != nil { + return schemas.Key{}, err + } + if len(supportedKeys) == 0 { + return schemas.Key{}, nil + } + if len(supportedKeys) == 1 { + return supportedKeys[0], nil + } + return bifrost.keySelector(ctx, supportedKeys, providerKey, model) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. @@ -3812,6 +4001,13 @@ type WSStreamHooks struct { ShortCircuitResponse *schemas.BifrostResponse } +// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a +// single realtime turn rather than one long-lived transport connection. +type RealtimeTurnHooks struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() +} + // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. 
@@ -3850,12 +4046,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3864,6 +4070,10 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3875,8 +4085,21 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } } + wsProvider, wsModel, _ := preReq.GetRequestFields() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, 
preCount) + // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) + // can read requestType/provider/model from the chunk or error. + if result != nil { + result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + if err != nil { + err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) + } + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } return &WSStreamHooks{ @@ -3885,6 +4108,94 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche }, nil } +// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for +// a single realtime turn. Unlike generic stream hooks, realtime turns do not +// support short-circuit responses in v1 because the transports cannot yet emit a +// fully synthetic assistant turn without an upstream generation. +func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { + if req == nil { + bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") + bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest + return nil, bifrostErr + } + if ctx == nil { + ctx = bifrost.ctx + } + + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { + ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) + } + + tracer := bifrost.getTracer() + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + + if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { + traceID := tracer.CreateTrace("") + if traceID != "" { + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + } + } + + pipeline := bifrost.getPluginPipeline() + cleanup := func() { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CleanupStreamAccumulator(traceID) + } + 
bifrost.releasePluginPipeline(pipeline) + } + provider, model, _ := req.GetRequestFields() + + preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) + if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + if shortCircuit != nil { + if shortCircuit.Error != nil { + shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + if bifrostErr != nil { + return nil, bifrostErr + } + return nil, shortCircuit.Error + } + if shortCircuit.Response != nil { + // Short-circuit responses are not supported for realtime turns (v1). + // Treat this like an error turn so plugins can close pending state cleanly. 
+ bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + } + + return &RealtimeTurnHooks{ + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + drainAndAttachPluginLogs(ctx) + return resp, bifrostErr + }, + Cleanup: cleanup, + }, nil +} + // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { @@ -4087,11 +4398,7 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) return nil, err } @@ -4124,16 +4431,6 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. 
// Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4176,29 +4473,10 @@ func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas. // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4213,11 +4491,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { - err.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + err.PopulateExtraFields(req.RequestType, provider, model, model) err.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest) return nil, err } @@ -4239,16 +4513,6 @@ func (bifrost *Bifrost) 
handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } return primaryResult, primaryErr } @@ -4289,29 +4553,10 @@ func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *sc // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { - fallbackErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: fallback.Provider, - ModelRequested: fallback.Model, - RawRequest: fallbackErr.ExtraFields.RawRequest, - RawResponse: fallbackErr.ExtraFields.RawResponse, - KeyStatuses: fallbackErr.ExtraFields.KeyStatuses, - } return nil, fallbackErr } } - if primaryErr != nil { - primaryErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - RawRequest: primaryErr.ExtraFields.RawRequest, - RawResponse: primaryErr.ExtraFields.RawResponse, - KeyStatuses: primaryErr.ExtraFields.KeyStatuses, - } - } - // All providers failed, return the original error return nil, primaryErr } @@ -4323,11 +4568,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4338,7 +4579,9 @@ func (bifrost 
*Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling requestHandler, so streaming goroutines @@ -4355,7 +4598,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4363,7 +4608,9 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return resp, nil @@ -4371,28 +4618,38 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx - // Check if provider is closing before attempting to send (lock-free atomic check) - // This prevents "send on closed channel" panics during provider removal/update + // 
If the queue is closing, check whether the provider was updated (new queue + // available) or removed. On update, transparently re-route to the new queue + // so in-flight producers don't get spurious errors. On removal, error out. + // + // Use a direct sync.Map lookup instead of getProviderQueue to avoid the + // lazy-creation path: getProviderQueue can resurrect a provider that was + // just removed by RemoveProvider if the account config still exists. if pq.isClosing() { - bifrost.releaseChannelMessage(msg) - bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, + var reroutedPq *ProviderQueue + if val, ok := bifrost.requestQueues.Load(provider); ok { + if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { + reroutedPq = candidate + } } - return nil, bifrostErr + if reroutedPq == nil { + bifrost.releaseChannelMessage(msg) + bifrostErr := newBifrostErrorFromMsg("provider is shutting down") + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + OriginalModelRequested: model, + } + return nil, bifrostErr + } + pq = reroutedPq } // Use select with done channel to detect shutdown during send @@ -4402,36 +4659,26 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := 
newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { @@ -4440,15 +4687,13 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4458,33 +4703,52 @@ func (bifrost *Bifrost) tryRequest(ctx 
*schemas.BifrostContext, req *schemas.Bif select { case result = <-msg.Response: resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount) + drainAndAttachPluginLogs(msg.Context) if bifrostErr != nil { bifrost.releaseChannelMessage(msg) return nil, bifrostErr } bifrost.releaseChannelMessage(msg) - // Checking if need to drop raw messages - // This we use for requests like containers, container files, skills etc. - if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop && resp != nil { - extraField := resp.GetExtraFields() - extraField.RawRequest = nil - extraField.RawResponse = nil + // Strip raw fields that were captured for logging but should not reach the client. + if resp != nil { + dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) + dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) + if dropReq || dropResp { + extraField := resp.GetExtraFields() + if dropReq { + extraField.RawRequest = nil + } + if dropResp { + extraField.RawResponse = nil + } + } } return resp, nil case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount) + drainAndAttachPluginLogs(msg.Context) bifrost.releaseChannelMessage(msg) - // Drop raw request/response on error path too - if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop { + // Strip raw fields on error path too. 
+ dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) + dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) + if dropReq || dropResp { if bifrostErrPtr != nil { - bifrostErrPtr.ExtraFields.RawRequest = nil - bifrostErrPtr.ExtraFields.RawResponse = nil + if dropReq { + bifrostErrPtr.ExtraFields.RawRequest = nil + } + if dropResp { + bifrostErrPtr.ExtraFields.RawResponse = nil + } } if resp != nil { extraField := resp.GetExtraFields() - extraField.RawRequest = nil - extraField.RawResponse = nil + if dropReq { + extraField.RawRequest = nil + } + if dropResp { + extraField.RawResponse = nil + } } } if bifrostErrPtr != nil { @@ -4492,9 +4756,17 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } return resp, nil case <-ctx.Done(): - bifrost.releaseChannelMessage(msg) + // Do NOT releaseChannelMessage here. The message is already enqueued and + // the worker still holds a reference to msg.Response and msg.Err. Returning + // those channels to the pool now would let the next request reuse them while + // the worker is still writing to them — stale data corruption. The worker + // never calls releaseChannelMessage itself, so this message leaks from the + // pool and is GC'd. That is intentional: a small pool leak on cancellation + // is far safer than corrupting another request's channels. 
provider, model, _ := req.GetRequestFields() - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "waiting for provider response") + bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } } @@ -4505,11 +4777,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } @@ -4520,7 +4788,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem tracer := bifrost.getTracer() if tracer == nil { - return nil, newBifrostErrorFromMsg("tracer not found in context") + bifrostErr := newBifrostErrorFromMsg("tracer not found in context") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr } // Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines @@ -4541,14 +4811,21 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } pipeline := bifrost.getPluginPipeline() - defer bifrost.releasePluginPipeline(pipeline) + releasePipeline := true + defer func() { + if releasePipeline { + bifrost.releasePluginPipeline(pipeline) + } + }() preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } 
return newBifrostMessageChan(resp), nil @@ -4556,13 +4833,23 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with stream if shortCircuit.Stream != nil { outputStream := make(chan *schemas.BifrostStreamChunk) + releasePipeline = false // pipeline is released inside the goroutine after stream drains // Create a post hook runner cause pipeline object is put back in the pool on defer pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { - return pipeline.RunPostLLMHooks(ctx, result, err, preCount) + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + return resp, bifrostErr } go func() { + defer func() { + drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk + pipeline.FinalizeStreamingPostHookSpans(ctx) + bifrost.releasePluginPipeline(pipeline) + }() defer close(outputStream) for streamMsg := range shortCircuit.Stream { @@ -4598,10 +4885,18 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // shared processedResponse or processedError objects. streamResponse := providerUtils.BuildClientStreamChunk(ctx, processedResponse, processedError) - // Send the processed message to the output stream - outputStream <- streamResponse + // Guarded send: if the consumer abandons outputStream (client + // disconnect, ctx cancel), drain the upstream shortCircuit.Stream + // so its producer can exit cleanly instead of blocking on its send. 
+ select { + case outputStream <- streamResponse: + case <-ctx.Done(): + for range shortCircuit.Stream { + } + return + } - //TODO: Release the processed response immediately after use + // TODO: Release the processed response immediately after use } }() @@ -4610,7 +4905,9 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Handle short-circuit with error if shortCircuit.Error != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) if bifrostErr != nil { + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } return newBifrostMessageChan(resp), nil @@ -4618,28 +4915,38 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx - // Check if provider is closing before attempting to send (lock-free atomic check) - // This prevents "send on closed channel" panics during provider removal/update + // If the queue is closing, check whether the provider was updated (new queue + // available) or removed. On update, transparently re-route to the new queue + // so in-flight producers don't get spurious errors. On removal, error out. + // + // Use a direct sync.Map lookup instead of getProviderQueue to avoid the + // lazy-creation path: getProviderQueue can resurrect a provider that was + // just removed by RemoveProvider if the account config still exists. 
if pq.isClosing() { - bifrost.releaseChannelMessage(msg) - bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, + var reroutedPq *ProviderQueue + if val, ok := bifrost.requestQueues.Load(provider); ok { + if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { + reroutedPq = candidate + } } - return nil, bifrostErr + if reroutedPq == nil { + bifrost.releaseChannelMessage(msg) + bifrostErr := newBifrostErrorFromMsg("provider is shutting down") + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + OriginalModelRequested: model, + } + return nil, bifrostErr + } + pq = reroutedPq } // Use select with done channel to detect shutdown during send @@ -4649,36 +4956,26 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, - } + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space") + bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") + bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) + return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") - bifrostErr.ExtraFields = 
schemas.BifrostErrorExtraFields{
-			RequestType:    req.RequestType,
-			Provider:       provider,
-			ModelRequested: model,
-		}
+		bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
 		return nil, bifrostErr
 	}

 	// Re-check closing flag before blocking send (lock-free atomic check)
 	if pq.isClosing() {
 		bifrost.releaseChannelMessage(msg)
 		bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			RequestType:    req.RequestType,
-			Provider:       provider,
-			ModelRequested: model,
-		}
+		bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
 		return nil, bifrostErr
 	}

 	select {
@@ -4687,15 +4984,13 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
 	case <-pq.done:
 		bifrost.releaseChannelMessage(msg)
 		bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			RequestType:    req.RequestType,
-			Provider:       provider,
-			ModelRequested: model,
-		}
+		bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
 		return nil, bifrostErr
 	case <-ctx.Done():
 		bifrost.releaseChannelMessage(msg)
-		return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for queue space")
+		bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space")
+		bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
+		return nil, bifrostErr
 	}
 }

@@ -4713,6 +5008,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
 		ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 		// On error we will complete post-hooks
 		recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load()))
+		drainAndAttachPluginLogs(ctx)
 		bifrost.releaseChannelMessage(msg)
 		if recoveredErr != nil {
 			return nil, recoveredErr
 		}
@@ -4721,16 +5017,28 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
 			return newBifrostMessageChan(recoveredResp), nil
 		}
 		return nil, &bifrostErrVal
+	case <-ctx.Done():
+		// Do NOT releaseChannelMessage here — see the identical note in tryRequest.
+		// Worker still holds msg.ResponseStream/msg.Err; releasing now corrupts the
+		// next request that reuses those pooled channels.
+		return nil, newBifrostCtxDoneError(ctx, "while waiting for stream response")
 	}
 }

-// executeRequestWithRetries is a generic function that handles common request processing logic
-// It consolidates retry logic, backoff calculation, and error handling
-// It is not a bifrost method because interface methods in go cannot be generic
+// executeRequestWithRetries is a generic function that handles common request processing logic.
+// It consolidates retry logic, backoff calculation, error handling, and key rotation.
+// It is not a bifrost method because interface methods in Go cannot be generic.
+//
+// keyProvider, when non-nil, is called on the first attempt and again whenever a rate-limit error
+// triggers a key rotation. It receives the set of key IDs already used in the current rotation
+// cycle so it can exclude them; when the pool is exhausted the provider resets the set and starts
+// a fresh weighted round. Network errors (5xx) reuse the same key since they are transient server
+// issues rather than per-key capacity problems.
 func executeRequestWithRetries[T any](
 	ctx *schemas.BifrostContext,
 	config *schemas.ProviderConfig,
-	requestHandler func() (T, *schemas.BifrostError),
+	requestHandler func(key schemas.Key) (T, *schemas.BifrostError),
+	keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error),
 	requestType schemas.RequestType,
 	providerKey schemas.ModelProvider,
 	model string,
@@ -4741,8 +5049,77 @@ func executeRequestWithRetries[T any](
 	var bifrostError *schemas.BifrostError
 	var attempts int

+	var currentKey schemas.Key
+	var usedKeyIDs map[string]bool
+	lastWasRateLimit := false
+
 	for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ {
 		ctx.SetValue(schemas.BifrostContextKeyNumberOfRetries, attempts)
+
+		// Reset the trail on the first attempt so a reused or shared context (bifrost.ctx)
+		// doesn't carry over records from a previous request.
+		if keyProvider != nil && attempts == 0 {
+			ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, []schemas.KeyAttemptRecord{})
+		}
+
+		// Select / rotate key: always on attempt 0, and again when the previous failure was a
+		// rate-limit (a different key may have remaining capacity). Network errors keep the same key.
+		if keyProvider != nil && (attempts == 0 || lastWasRateLimit) {
+			if usedKeyIDs == nil {
+				usedKeyIDs = make(map[string]bool)
+			}
+
+			// Wrap key selection in a dedicated span so traces show which key was chosen
+			// (and when rotation happened). The span is opened before keyProvider is called
+			// so selection errors are captured too.
+			keyTracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
+			var keySpanCtx context.Context
+			var keyHandle schemas.SpanHandle
+			if keyTracer != nil {
+				keySpanCtx, keyHandle = keyTracer.StartSpan(ctx, "key.selection", schemas.SpanKindInternal)
+				keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(providerKey))
+				keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model)
+				if attempts > 0 {
+					keyTracer.SetAttribute(keyHandle, "retry.count", attempts)
+				}
+			}
+
+			selectedKey, err := keyProvider(usedKeyIDs)
+
+			if keyTracer != nil {
+				if err != nil {
+					keyTracer.SetAttribute(keyHandle, "error", err.Error())
+					keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error())
+				} else {
+					keyTracer.SetAttribute(keyHandle, "key.id", selectedKey.ID)
+					keyTracer.SetAttribute(keyHandle, "key.name", selectedKey.Name)
+					keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "")
+					// Propagate the span context so subsequent spans (llm.call / retry.attempt.N)
+					// are correctly linked in the trace hierarchy.
+					ctx.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID))
+				}
+			}
+
+			if err != nil {
+				var zero T
+				return zero, newBifrostErrorFromMsg(err.Error())
+			}
+			currentKey = selectedKey
+			ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, currentKey.ID)
+			ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, currentKey.Name)
+		}
+
+		// Append a trail record for every attempt (key rotation and same-key retries alike).
+		// Skipped when keyProvider is nil (keyless providers have no key to track).
+		// FailReason is populated below once the attempt outcome is known.
+		if keyProvider != nil {
+			schemas.AppendToContextList(ctx, schemas.BifrostContextKeyAttemptTrail, schemas.KeyAttemptRecord{
+				Attempt: attempts,
+				KeyID:   currentKey.ID,
+				KeyName: currentKey.Name,
+			})
+		}
+
 		if attempts > 0 {
 			// Log retry attempt
 			var retryMsg string
@@ -4834,7 +5211,7 @@ func executeRequestWithRetries[T any](
 		}

 		// Attempt the request
-		result, bifrostError = requestHandler()
+		result, bifrostError = requestHandler(currentKey)

 		// For streaming requests that returned success, check if the first chunk
 		// is actually an error (e.g., rate limits sent as SSE events in HTTP 200).
@@ -4842,7 +5219,7 @@ func executeRequestWithRetries[T any](
 		// the SSE stream instead of returning proper HTTP error status codes.
 		if bifrostError == nil {
 			if streamChan, ok := any(result).(chan *schemas.BifrostStreamChunk); ok {
-				checkedStream, drainDone, firstChunkErr := providerUtils.CheckFirstStreamChunkForError(streamChan)
+				checkedStream, drainDone, firstChunkErr := providerUtils.CheckFirstStreamChunkForError(ctx, streamChan)
 				if firstChunkErr != nil {
 					<-drainDone
 					bifrostError = firstChunkErr
@@ -4870,7 +5247,7 @@ func executeRequestWithRetries[T any](
 			} else {
 				// Populate LLM response attributes for non-streaming responses
 				if resp, ok := any(result).(*schemas.BifrostResponse); ok {
-					tracer.PopulateLLMResponseAttributes(handle, resp, bifrostError)
+					tracer.PopulateLLMResponseAttributes(ctx, handle, resp, bifrostError)
 				}

 				// End span with appropriate status
@@ -4898,25 +5275,50 @@ func executeRequestWithRetries[T any](
 		// Check if we should retry based on status code or error message
 		shouldRetry := false
-
-		if bifrostError.Error != nil && (bifrostError.Error.Message == schemas.ErrProviderDoRequest || bifrostError.Error.Message == schemas.ErrProviderNetworkError) {
-			shouldRetry = true
-			logger.Debug("detected request HTTP/network error, will retry: %s", bifrostError.Error.Message)
-		}
-
-		// Retry if status code or error object indicates rate limiting
-		if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) ||
+		isRateLimit := (bifrostError.StatusCode != nil && *bifrostError.StatusCode == 429) ||
 			(bifrostError.Error != nil &&
 				(IsRateLimitErrorMessage(bifrostError.Error.Message) ||
 					(bifrostError.Error.Type != nil && IsRateLimitErrorMessage(*bifrostError.Error.Type)) ||
-					(bifrostError.Error.Code != nil && IsRateLimitErrorMessage(*bifrostError.Error.Code)))) {
+					(bifrostError.Error.Code != nil && IsRateLimitErrorMessage(*bifrostError.Error.Code))))
+
+		errMessage := GetErrorMessage(bifrostError)
+
+		if bifrostError.Error != nil &&
+			(bifrostError.Error.Message == schemas.ErrProviderDoRequest ||
+				bifrostError.Error.Message == schemas.ErrProviderNetworkError) {
+			shouldRetry = true
+			logger.Debug("detected request HTTP/network error, will retry: %s", errMessage)
+		} else if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) || isRateLimit {
 			shouldRetry = true
-			logger.Debug("detected rate limit error in message, will retry: %s", bifrostError.Error.Message)
+			logger.Debug("encountered error that should be retried: %s", errMessage)
+		}
+
+		// Fill FailReason on any failed attempt (retryable or terminal).
+		// Use the provider error type when present; fall back to "unknown".
+		if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 {
+			reason := "unknown"
+			if bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type != "" {
+				reason = *bifrostError.Error.Type
+			} else if isRateLimit {
+				reason = "rate_limit_error"
+			}
+			trail[len(trail)-1].FailReason = &reason
+			ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, trail)
+		}

 		if !shouldRetry {
 			break
 		}
+
+		// Mark current key as used so the next selection excludes it (rate-limit only).
+		// Network errors keep the same key — they are transient server issues, not per-key.
+		if isRateLimit && keyProvider != nil {
+			if usedKeyIDs == nil {
+				usedKeyIDs = make(map[string]bool)
+			}
+			usedKeyIDs[currentKey.ID] = true
+		}
+		lastWasRateLimit = isRateLimit
 	}

 	// Add retry information to error
@@ -4924,6 +5326,13 @@ func executeRequestWithRetries[T any](
 		logger.Debug("request failed after %d %s", attempts, map[bool]string{true: "attempts", false: "attempt"}[attempts > 1])
 	}

+	// On final error, clear selected_key so it only reflects a key that actually served a successful response.
+	// The attempt trail is the authoritative record of which keys were tried.
+	if bifrostError != nil && keyProvider != nil {
+		ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, "")
+		ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, "")
+	}
+
 	return result, bifrostError
 }

@@ -4937,7 +5346,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 		}
 	}()

-	for req := range pq.queue {
+	for {
+		var req *ChannelMessage
+		select {
+		case r := <-pq.queue:
+			req = r
+		case <-pq.done:
+			// Provider is shutting down. Drain any buffered requests and send
+			// back errors so callers are not left blocked on their response channel.
+			for {
+				select {
+				case r := <-pq.queue:
+					provKey, mod, _ := r.GetRequestFields()
+					select {
+					case r.Err <- schemas.BifrostError{
+						IsBifrostError: false,
+						Error: &schemas.ErrorField{
+							Message: "provider is shutting down",
+						},
+						ExtraFields: schemas.BifrostErrorExtraFields{
+							RequestType:            r.RequestType,
+							Provider:               provKey,
+							OriginalModelRequested: mod,
+						},
+					}:
+					case <-r.Context.Done():
+					}
+				default:
+					return
+				}
+			}
+		}
+
 		_, model, _ := req.BifrostRequest.GetRequestFields()

 		var result *schemas.BifrostResponse
@@ -4953,30 +5393,64 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 		req.Context.SetValue(schemas.BifrostContextKeyIsCustomProvider, !IsStandardProvider(baseProvider))

 		// Determine whether this provider attempt should capture raw payloads.
-		// logging-only mode (store_raw_request_response=true, send_back_raw_*=false):
-		//   sets BifrostContextKeySendBackRaw* = true so providers capture via the unified
-		//   ShouldSendBackRaw* path, and sets BifrostContextKeyRawRequestResponseForLogging
-		//   so the payload is stripped before the response reaches the client.
-		// full send-back mode (send_back_raw_request/response=true):
-		//   BifrostContextKeySendBackRaw* are set as before; stripping flag stays false.
-		// Always set both flags explicitly so stale values from a previous provider
-		// attempt (e.g. first attempt was logging-only, fallback is full send-back)
-		// cannot leak into the new attempt on a reused context.
-		existingSendBackReq, _ := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool)
-		existingSendBackResp, _ := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool)
-		loggingOnly := config.StoreRawRequestResponse &&
-			!config.SendBackRawRequest && !existingSendBackReq &&
-			!config.SendBackRawResponse && !existingSendBackResp
-		req.Context.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, loggingOnly)
-		if loggingOnly {
-			// Enable capture via the standard flags so ShouldSendBackRaw* needs only one check.
-			req.Context.SetValue(schemas.BifrostContextKeySendBackRawRequest, true)
-			req.Context.SetValue(schemas.BifrostContextKeySendBackRawResponse, true)
-		}
-
-		key := schemas.Key{}
+		//
+		// Effective values are computed by merging provider config with any per-request
+		// context overrides (BifrostContextKeySendBackRawRequest/Response and
+		// BifrostContextKeyStoreRawRequestResponse). A context value set to either true
+		// or false fully overrides the provider config for that flag.
+		//
+		// Each flag is independent:
+		//   send_back_raw_request      — include raw request bytes in the client response.
+		//   send_back_raw_response     — include raw response bytes in the client response.
+		//   store_raw_request_response — persist raw bytes in log records (logging plugin only).
+		//
+		// Capture is enabled per-side whenever send-back OR store is requested for that side.
+		// Strip flags tell the response path to remove that side's bytes before the payload
+		// reaches the caller (used when store=true but send-back=false for that side).
+		//
+		// All internal signals are always written explicitly on every attempt so stale values
+		// from a previous provider attempt (e.g. different fallback provider config) cannot
+		// leak into the new attempt on a reused context. The user override keys
+		// (BifrostContextKeySendBackRaw*, BifrostContextKeyStoreRawRequestResponse) are
+		// never overwritten — they are read-only from bifrost.go's perspective.

+		// Step 1: compute effective value for each flag (provider config ← per-request override).
+		effectiveSendBackReq := config.SendBackRawRequest
+		if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok {
+			effectiveSendBackReq = override
+		}
+		effectiveSendBackResp := config.SendBackRawResponse
+		if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok {
+			effectiveSendBackResp = override
+		}
+		effectiveStore := config.StoreRawRequestResponse
+		if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok {
+			effectiveStore = override
+		}
+
+		// Step 2: derive per-side capture and strip flags.
+		// Capture if we need to send the data back OR store it — independent per side.
+		captureReq := effectiveSendBackReq || effectiveStore
+		captureResp := effectiveSendBackResp || effectiveStore
+		// Strip from client response if we captured for storage but not for send-back.
+		dropReq := effectiveStore && !effectiveSendBackReq
+		dropResp := effectiveStore && !effectiveSendBackResp
+
+		// Step 3: write all internal signals explicitly (never touch the user override keys).
+		req.Context.SetValue(schemas.BifrostContextKeyCaptureRawRequest, captureReq)
+		req.Context.SetValue(schemas.BifrostContextKeyCaptureRawResponse, captureResp)
+		req.Context.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, dropReq)
+		req.Context.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, dropResp)
+		// Tells the logging plugin whether to persist raw bytes in log records.
+		req.Context.SetValue(schemas.BifrostContextKeyShouldStoreRawInLogs, effectiveStore)

 		var keys []schemas.Key
-		if providerRequiresKey(baseProvider, config.CustomProviderConfig) {
+		// keyProvider is passed to executeRequestWithRetries to manage key selection and rotation.
+		// It is nil when no key is required (e.g. providerRequiresKey=false) or for multi-key
+		// batch/file/container operations that manage their own key lists.
+		var keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error)
+
+		if providerRequiresKey(config.CustomProviderConfig) {
 			// ListModels needs all enabled/supported keys so providers can aggregate
 			// and report per-key statuses (KeyStatuses).
 			if req.RequestType == schemas.ListModelsRequest {
@@ -4990,9 +5464,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 						Error: err,
 					},
 					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider:       provider.GetProviderKey(),
-						ModelRequested: model,
-						RequestType:    req.RequestType,
+						Provider:               provider.GetProviderKey(),
+						RequestType:            req.RequestType,
+						OriginalModelRequested: model,
+						ResolvedModelUsed:      model,
 					},
 				}
 				continue
@@ -5019,81 +5494,135 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 						Error: err,
 					},
 					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider:       provider.GetProviderKey(),
-						ModelRequested: model,
-						RequestType:    req.RequestType,
+						Provider:               provider.GetProviderKey(),
+						RequestType:            req.RequestType,
+						OriginalModelRequested: model,
+						ResolvedModelUsed:      model,
 					},
 				}
 				continue
 			}
 		} else {
-			// Use the custom provider name for actual key selection, but pass base provider type for key validation
-			// Start span for key selection
-			keyTracer := bifrost.getTracer()
-			keySpanCtx, keyHandle := keyTracer.StartSpan(req.Context, "key.selection", schemas.SpanKindInternal)
-			keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(provider.GetProviderKey()))
-			keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model)
-
-			key, err = bifrost.selectKeyFromProviderForModel(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
-			if err != nil {
-				keyTracer.SetAttribute(keyHandle, "error", err.Error())
-				keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error())
-				bifrost.logger.Debug("error selecting key for model %s: %v", model, err)
+			// Build the key pool for this request. Selection and rotation are deferred to
+			// executeRequestWithRetries via keyProvider so that each retry attempt can use
+			// a different key (on rate-limit errors) without re-running the full filtering.
+			supportedKeys, canRotate, keyPoolErr := bifrost.selectKeyFromProviderForModelWithPool(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
+			if keyPoolErr != nil {
+				bifrost.logger.Debug("error building key pool for model %s: %v", model, keyPoolErr)
 				req.Err <- schemas.BifrostError{
 					IsBifrostError: false,
 					Error: &schemas.ErrorField{
-						Message: err.Error(),
-						Error:   err,
+						Message: keyPoolErr.Error(),
+						Error:   keyPoolErr,
 					},
 					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider:       provider.GetProviderKey(),
-						ModelRequested: model,
-						RequestType:    req.RequestType,
+						Provider:               provider.GetProviderKey(),
+						RequestType:            req.RequestType,
+						OriginalModelRequested: model,
+						ResolvedModelUsed:      model,
 					},
 				}
 				continue
 			}
-			keyTracer.SetAttribute(keyHandle, "key.id", key.ID)
-			keyTracer.SetAttribute(keyHandle, "key.name", key.Name)
-			keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "")
-			// Update context with span ID for subsequent operations
-			req.Context.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID))
-			req.Context.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID)
-			req.Context.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name)
+
+			if len(supportedKeys) == 0 {
+				// SkipKeySelection path — keyProvider stays nil, zero Key is used.
+			} else if !canRotate {
+				// Fixed key (DirectKey, explicit ID/name, session stickiness): always
+				// return the same key regardless of usedKeyIDs.
+				fixedKey := supportedKeys[0]
+				keyProvider = func(_ map[string]bool) (schemas.Key, error) {
+					return fixedKey, nil
+				}
+			} else {
+				// Rotating pool: weighted selection with per-cycle exclusion.
+				// Captures supportedKeys, bifrost.keySelector, provider/model by value.
+				pool := supportedKeys
+				provKey := provider.GetProviderKey()
+				mdl := model
+				keyProvider = func(usedKeyIDs map[string]bool) (schemas.Key, error) {
+					available := make([]schemas.Key, 0, len(pool))
+					for _, k := range pool {
+						if !usedKeyIDs[k.ID] {
+							available = append(available, k)
+						}
+					}
+					if len(available) == 0 {
+						// All keys exhausted — start a fresh weighted round.
+						for id := range usedKeyIDs {
+							delete(usedKeyIDs, id)
+						}
+						available = pool
+					}
+					return bifrost.keySelector(req.Context, available, provKey, mdl)
+				}
+			}
 		}
 		}
 		}

+		originalModelRequested := model
+		// resolvedModel is set inside the handler closures below on every attempt so that each
+		// key's own alias mapping is applied. postHookRunner captures resolvedModel by reference
+		// (Go closure semantics) and will therefore always see the value from the last attempt.
+		var resolvedModel string
+
 		// Create plugin pipeline for streaming requests outside retry loop to prevent leaks
 		var postHookRunner schemas.PostHookRunner
 		var pipeline *PluginPipeline
 		if IsStreamRequestType(req.RequestType) {
 			pipeline = bifrost.getPluginPipeline()
 			postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
+				// Populate extra fields before RunPostLLMHooks so plugins (e.g. logging)
+				// can read requestType/provider/model from the chunk or error.
+				// resolvedModel is captured by reference and reflects the alias from the last attempt.
+				if result != nil {
+					result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
+				}
+				if err != nil {
+					err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
+				}
 				resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
+				if IsFinalChunk(ctx) {
+					drainAndAttachPluginLogs(ctx)
+				}
 				if bifrostErr != nil {
+					bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
 					return nil, bifrostErr
+				} else if resp != nil {
+					resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
 				}
 				return resp, nil
 			}

-			// Store a finalizer callback to create aggregated post-hook spans at stream end
-			// This closure captures the pipeline reference and releases it after finalization
+			// Store a finalizer callback to create aggregated post-hook spans at stream end.
+			// Wrapped in sync.Once so the normal end-of-stream invocation and a deferred
+			// safety-net invocation (e.g. from a provider goroutine's panic path) cannot
+			// double-release the pipeline.
+			var finalizerOnce sync.Once
 			postHookSpanFinalizer := func(ctx context.Context) {
-				pipeline.FinalizeStreamingPostHookSpans(ctx)
-				// Release the pipeline AFTER finalizing spans (not before streaming completes)
-				bifrost.releasePluginPipeline(pipeline)
+				finalizerOnce.Do(func() {
+					pipeline.FinalizeStreamingPostHookSpans(ctx)
+					bifrost.releasePluginPipeline(pipeline)
+				})
 			}
 			req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer)
 		}

-		// Execute request with retries
+		// Execute request with retries. Each handler invocation resolves the alias for the key
+		// selected by keyProvider on that attempt and mutates the worker-local request model.
+		// resolvedModel (captured by reference in postHookRunner) is updated accordingly.
 		if IsStreamRequestType(req.RequestType) {
-			stream, bifrostError = executeRequestWithRetries(req.Context, config, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-				return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
-			}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
+			stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
+				resolvedModel = k.Aliases.Resolve(originalModelRequested)
+				req.SetModel(resolvedModel)
+				return bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner)
+			}, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
 		} else {
-			result, bifrostError = executeRequestWithRetries(req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) {
-				return bifrost.handleProviderRequest(provider, config, req, key, keys)
-			}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
+			result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
+				resolvedModel = k.Aliases.Resolve(originalModelRequested)
+				req.SetModel(resolvedModel)
+				return bifrost.handleProviderRequest(provider, config, req, k, keys)
+			}, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
 		}

 		// Release pipeline immediately for non-streaming requests only
@@ -5104,14 +5633,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 		}

 		if bifrostError != nil {
-			bifrostError.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       provider.GetProviderKey(),
-				ModelRequested: model,
-				RequestType:    req.RequestType,
-				RawRequest:     bifrostError.ExtraFields.RawRequest,
-				RawResponse:    bifrostError.ExtraFields.RawResponse,
-				KeyStatuses:    bifrostError.ExtraFields.KeyStatuses,
-			}
+			bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)

 			// Send error with context awareness to prevent deadlock
 			select {
@@ -5125,6 +5647,9 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 				bifrost.logger.Warn("Timeout while sending error response, client may have disconnected")
 			}
 		} else {
+			if result != nil {
+				result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
+			}
 			if IsStreamRequestType(req.RequestType) {
 				// Send stream with context awareness to prevent deadlock
 				select {
@@ -5168,12 +5693,34 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config
 		}
 		response.ListModelsResponse = listModelsResponse
 	case schemas.TextCompletionRequest:
+		if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest {
+			chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest()
+			if chatRequest != nil {
+				chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest)
+				if bifrostError != nil {
+					return nil, bifrostError
+				}
+				response.TextCompletionResponse = chatCompletionResponse.ToBifrostTextCompletionResponse()
+				break
+			}
+		}
 		textCompletionResponse, bifrostError := provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest)
 		if bifrostError != nil {
 			return nil, bifrostError
 		}
 		response.TextCompletionResponse = textCompletionResponse
 	case schemas.ChatCompletionRequest:
+		if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest {
+			responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest()
+			if responsesRequest != nil {
+				responsesResponse, bifrostError := provider.Responses(req.Context, key, responsesRequest)
+				if bifrostError != nil {
+					return nil, bifrostError
+				}
+				response.ChatResponse = responsesResponse.ToBifrostChatResponse()
+				break
+			}
+		}
 		chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, req.BifrostRequest.ChatRequest)
 		if bifrostError != nil {
 			return nil, bifrostError
@@ -5206,6 +5753,16 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config
 		}
 		response.RerankResponse = rerankResponse
 	case schemas.OCRRequest:
+		var customProviderConfig *schemas.CustomProviderConfig
+		if config != nil {
+			customProviderConfig = config.CustomProviderConfig
+		}
+		if bifrostError := providerUtils.CheckOperationAllowed(provider.GetProviderKey(), customProviderConfig, schemas.OCRRequest); bifrostError != nil {
+			if req.BifrostRequest.OCRRequest != nil {
+				bifrostError.ExtraFields.OriginalModelRequested = req.BifrostRequest.OCRRequest.Model
+			}
+			return nil, bifrostError
+		}
 		ocrResponse, bifrostError := provider.OCR(req.Context, key, req.BifrostRequest.OCRRequest)
 		if bifrostError != nil {
 			return nil, bifrostError
@@ -5223,6 +5780,7 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config
 		if bifrostError != nil {
 			return nil, bifrostError
 		}
+		transcriptionResponse.BackfillParams(req.BifrostRequest.TranscriptionRequest)
 		response.TranscriptionResponse = transcriptionResponse
 	case schemas.ImageGenerationRequest:
 		imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest)
@@ -5416,9 +5974,10 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config
 			Message: fmt.Sprintf("unsupported request type: %s", req.RequestType),
 		},
 		ExtraFields: schemas.BifrostErrorExtraFields{
-			RequestType:    req.RequestType,
-			Provider:       provider.GetProviderKey(),
-			ModelRequested: model,
+			RequestType:            req.RequestType,
+			Provider:               provider.GetProviderKey(),
+			OriginalModelRequested: model,
+			ResolvedModelUsed:      model,
 		},
 	}
 }
@@ -5429,8 +5988,20 @@ func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config

 func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
 	switch req.RequestType {
 	case schemas.TextCompletionStreamRequest:
+		if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest {
+			chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest()
+			if chatRequest != nil {
+				return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), key, chatRequest)
+			}
+		}
 		return provider.TextCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.TextCompletionRequest)
 	case schemas.ChatCompletionStreamRequest:
+		if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest {
+			responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest()
+			if responsesRequest != nil {
+				return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), key, responsesRequest)
+			}
+		}
 		return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest)
 	case schemas.ResponsesStreamRequest:
 		return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest)
@@ -5452,9 +6023,10 @@ func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, r
 			Message: fmt.Sprintf("unsupported request type: %s", req.RequestType),
 		},
 		ExtraFields: schemas.BifrostErrorExtraFields{
-			RequestType:    req.RequestType,
-			Provider:       provider.GetProviderKey(),
-			ModelRequested: model,
+			RequestType:            req.RequestType,
+			Provider:               provider.GetProviderKey(),
+			OriginalModelRequested: model,
+			ResolvedModelUsed:      model,
 		},
 	}
 }
@@ -5476,7 +6048,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
 		return nil, &schemas.BifrostError{
 			IsBifrostError: false,
 			Error: &schemas.ErrorField{
-				Message: "MCP is not configured in this Bifrost instance",
+				Message: "mcp is not configured in this bifrost instance",
 			},
 			ExtraFields: schemas.BifrostErrorExtraFields{
 				RequestType: requestType,
@@ -5501,6 +6073,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
 	// Handle short-circuit with response (success case)
 	if shortCircuit.Response != nil {
 		finalMcpResp, bifrostErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, preCount)
+		drainAndAttachPluginLogs(ctx)
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
@@ -5510,6 +6083,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
 	if shortCircuit.Error != nil {
 		// Capture post-hook results to respect transformations or recovery
 		finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, preCount)
+		drainAndAttachPluginLogs(ctx)
 		// Return post-hook error if present (post-hook may have transformed the error)
 		if finalErr != nil {
 			return nil, finalErr
@@ -5552,6 +6126,11 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
 				RequestType: requestType,
 			},
 		}
+		// Preserve MCPUserOAuthRequiredError for downstream detection in agent mode
+		var oauthErr *schemas.MCPUserOAuthRequiredError
+		if errors.As(err, &oauthErr) {
+			bifrostErr.ExtraFields.MCPAuthRequired = oauthErr
+		}
 	} else if result == nil {
 		bifrostErr = &schemas.BifrostError{
 			IsBifrostError: false,
@@ -5569,6 +6148,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR

 	// Run post-hooks
 	finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, bifrostErr, preCount)
+	drainAndAttachPluginLogs(ctx)
 	if finalErr != nil {
 		return nil, finalErr
@@ -5604,6 +6184,9 @@ func (bifrost *Bifrost) executeMCPToolWithHooks(ctx *schemas.BifrostContext, req

 	resp, bifrostErr := bifrost.handleMCPToolExecution(ctx, request, requestType)
 	if bifrostErr != nil {
+		if bifrostErr.ExtraFields.MCPAuthRequired != nil {
+			return nil, bifrostErr.ExtraFields.MCPAuthRequired
+		}
 		return nil, fmt.Errorf("%s", GetErrorMessage(bifrostErr))
 	}
 	return resp, nil
@@ -5633,7 +6216,9 @@ func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schema
 			}
 		}

-		req, shortCircuit, err = plugin.PreLLMHook(ctx, req)
+		pluginCtx := ctx.WithPluginScope(&pluginName)
+		req, shortCircuit, err = plugin.PreLLMHook(pluginCtx, req)
+		pluginCtx.ReleasePluginScope()

 		// End span with appropriate status
 		if err != nil {
@@ -5673,8 +6258,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
 	if runFrom > len(p.llmPlugins) {
 		runFrom = len(p.llmPlugins)
 	}
-	// Detect streaming mode - if StreamStartTime is set, we're in a streaming context
-	isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil
+	requestType, _, _, _ := GetResponseFields(resp, bifrostErr)
+	// Realtime turns carry StreamStartTime for plugin latency/final-chunk context,
+	// but they are finalized as one completed turn, not chunk-by-chunk stream output.
+ isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error @@ -5684,8 +6271,17 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche p.logger.Debug("running post-hook for plugin %s", pluginName) if isStreaming { // For streaming: accumulate timing, don't create individual spans per chunk + // Lazily create cached scoped contexts on first chunk (reused across all chunks) + if p.streamScopedCtxs == nil { + p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins)) + for _, pl := range p.llmPlugins { + name := pl.GetName() + p.streamScopedCtxs[name] = ctx.WithPluginScope(&name) + } + } + pluginCtx := p.streamScopedCtxs[pluginName] start := time.Now() - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) duration := time.Since(start) p.accumulatePluginTiming(pluginName, duration, err != nil) @@ -5702,7 +6298,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } - resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) @@ -5756,7 +6354,9 @@ func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schema } } - req, shortCircuit, err = plugin.PreMCPHook(ctx, req) + pluginCtx := ctx.WithPluginScope(&pluginName) + req, shortCircuit, err = plugin.PreMCPHook(pluginCtx, req) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5811,7 +6411,9 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx 
*schemas.BifrostContext, mcpResp *s } } - mcpResp, bifrostErr, err = plugin.PostMCPHook(ctx, mcpResp, bifrostErr) + pluginCtx := ctx.WithPluginScope(&pluginName) + mcpResp, bifrostErr, err = plugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr) + pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { @@ -5837,7 +6439,11 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s return mcpResp, nil } -// resetPluginPipeline resets a PluginPipeline instance for reuse +// resetPluginPipeline resets a PluginPipeline instance for reuse. +// IMPORTANT: drainAndAttachPluginLogs must be called on the root BifrostContext +// BEFORE this method, because it calls ReleasePluginScope on cached scoped contexts +// which nils out their pluginLogs pointer. The drain reads from the shared store +// on the root context, so it must happen while the store is still referenced. func (p *PluginPipeline) resetPluginPipeline() { p.executedPreHooks = 0 p.preHookErrors = p.preHookErrors[:0] @@ -5848,6 +6454,25 @@ func (p *PluginPipeline) resetPluginPipeline() { clear(p.postHookTimings) } p.postHookPluginOrder = p.postHookPluginOrder[:0] + // Release cached scoped contexts for streaming + for _, scopedCtx := range p.streamScopedCtxs { + scopedCtx.ReleasePluginScope() + } + p.streamScopedCtxs = nil +} + +// drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext +// and attaches them to the trace for later retrieval by observability plugins. 
+func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) { + tracer, traceID, err := GetTracerFromContext(ctx) + if err != nil || tracer == nil || traceID == "" { + return + } + logs := ctx.DrainPluginLogs() + if len(logs) == 0 { + return + } + tracer.AttachPluginLogs(traceID, logs) } // accumulatePluginTiming accumulates timing for a plugin during streaming @@ -5939,7 +6564,9 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { return pipeline } -// releasePluginPipeline returns a PluginPipeline to the pool +// releasePluginPipeline returns a PluginPipeline to the pool. +// Caller must ensure drainAndAttachPluginLogs has already been called on the +// associated BifrostContext before calling this method. func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { pipeline.resetPluginPipeline() bifrost.pluginPipelinePool.Put(pipeline) @@ -5984,6 +6611,47 @@ func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest) *ChannelMe return msg } +// drainQueueWithErrors drains all buffered messages from pq and sends each a +// "provider is shutting down" error. It must be called after all workers for +// the queue have exited (i.e. after wg.Wait()) to cover the TOCTOU window: +// a producer that passed isClosing() just before signalClosing fired can still +// win the `case pq.queue <- msg` branch in tryRequest, landing a message in +// the queue after the last worker's drain loop already exited via `default:`. +// Without this sweep, those callers block forever on <-msg.Response / <-msg.Err. +// +// Residual TOCTOU window (known limitation): this sweep runs exactly once via +// a non-blocking `select { default: }`. A producer that deposits a message +// after the sweep's `default:` branch exits has no worker and no sweep to drain +// it — the caller will block until its own context is cancelled. 
Fully closing +// this window requires a sender-side reference count (so the last producer can +// signal "queue is fully idle"), which is intentionally not implemented because +// it would add per-send atomic overhead on the hot path. +func (bifrost *Bifrost) drainQueueWithErrors(pq *ProviderQueue) { + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + select { + case r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{Message: "provider is shutting down"}, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + OriginalModelRequested: mod, + }, + }: + case <-r.Context.Done(): + // No time.After needed: r.Err is a buffered channel of size 1 freshly + // allocated per request, so the send always completes immediately unless + // the caller already cancelled. ctx.Done() is the only valid escape. + } + default: + return + } + } +} + // releaseChannelMessage returns a ChannelMessage and its channels to their respective pools. 
func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { // Put channels back in pools @@ -6108,17 +6776,21 @@ func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, provide // Filter keys for ListModels - only check if key has a value var supportedKeys []schemas.Key - for _, k := range keys { + for _, key := range keys { // Skip disabled keys (default enabled when nil) - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + if err := validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } - bifrost.logger.Debug("[Bifrost] Provider %s: %d enabled keys found", providerKey, len(supportedKeys)) + bifrost.logger.Debug("[Bifrost] Provider %s: %d valid keys found", providerKey, len(supportedKeys)) if len(supportedKeys) == 0 { return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey) @@ -6161,17 +6833,21 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p continue } + if err := validateKey(baseProviderType, &k); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", k.Name, k.ID, providerKey, err.Error()) + continue + } + // Model filtering logic: // - If model is nil or empty → include all keys (no model filter) // - If model is specified: // - If model is in key.BlacklistedModels → exclude (wins over Models allow list) - // - If key.Models is empty → include key (supports all non-blacklisted models) + // - If key.Models is ["*"] → include key (supports all non-blacklisted models) + 
// - If key.Models is empty → exclude key (deny-by-default) // - If key.Models is non-empty → only include if model is in list + // Blacklist wins over allowlist if model != nil && *model != "" { - if len(k.BlacklistedModels) > 0 && slices.Contains(k.BlacklistedModels, *model) { - continue - } - if len(k.Models) > 0 && !slices.Contains(k.Models, *model) { + if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) { continue } } @@ -6201,28 +6877,39 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p return filteredKeys, nil } -// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. -// It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) (schemas.Key, error) { - // Check if key has been set in the context explicitly +// selectKeyFromProviderForModelWithPool returns the filtered pool of eligible keys for the given +// provider/model, along with a canRotate flag indicating whether key rotation across retries is +// permitted. Key selection (choosing which key to use) is deferred to executeRequestWithRetries +// via the keyProvider closure built by the caller. +// +// canRotate=false is returned for cases where the caller must always use the same key: +// - DirectKey (caller-supplied key bypasses all selection) +// - SkipKeySelection (provider allows keyless requests; empty slice returned) +// - Explicit BifrostContextKeyAPIKeyID / APIKeyName (user pinned a specific key) +// - Session stickiness (key persisted in KV store for the session lifetime) +// - Single-key pool (only one eligible key — rotation is a no-op, KV write skipped) +// +// canRotate=true is returned when there are two or more eligible keys and no pinning +// or stickiness constraint is in effect. 
+func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) ([]schemas.Key, bool, error) { + // DirectKey: caller supplied a key directly — no pool, no rotation. if ctx != nil { - key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) - if ok { - return key, nil + if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok { + return []schemas.Key{key}, false, nil } } - // Check if key skipping is allowed + // SkipKeySelection: provider allows keyless requests — return empty pool, no rotation. if skipKeySelection, ok := ctx.Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) { - return schemas.Key{}, nil + return []schemas.Key{}, false, nil } + // Get keys for provider keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { - return schemas.Key{}, err + return nil, false, err } - // Check if no keys found if len(keys) == 0 { - return schemas.Key{}, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) + return nil, false, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) } // For batch API operations, filter keys to only include those with UseForBatchAPI enabled @@ -6234,7 +6921,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex } } if len(batchEnabledKeys) == 0 { - return schemas.Key{}, fmt.Errorf("no config found for batch APIs. 
Please enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey) + return nil, false, fmt.Errorf("no config found for batch apis; enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey) } keys = batchEnabledKeys } @@ -6248,112 +6935,89 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest if skipModelCheck { // When skipping model check: just verify keys are enabled and have values - for _, k := range keys { + for _, key := range keys { // Skip disabled keys - if k.Enabled != nil && !*k.Enabled { + if key.Enabled != nil && !*key.Enabled { continue } - if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { - supportedKeys = append(supportedKeys, k) + if err := validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue + } + if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { + supportedKeys = append(supportedKeys, key) } } } else { - // When NOT skipping model check: do full model/deployment filtering + // When NOT skipping model check: do full model filtering for _, key := range keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } - hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) - var modelSupported bool - if len(key.BlacklistedModels) > 0 && slices.Contains(key.BlacklistedModels, model) { - modelSupported = false - } else { - modelSupported = (len(key.Models) == 0 && hasValue) || (slices.Contains(key.Models, model) && 
hasValue) + if err := validateKey(baseProviderType, &key); err != nil { + bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) + continue } - // Additional deployment checks for Azure, Bedrock and Vertex - deploymentSupported := true - if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil { - // For Azure, check if deployment exists for this model - if len(key.AzureKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.AzureKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Bedrock && key.BedrockKeyConfig != nil { - // For Bedrock, check if deployment exists for this model - if len(key.BedrockKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.BedrockKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Vertex && key.VertexKeyConfig != nil { - // For Vertex, check if deployment exists for this model - if len(key.VertexKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.VertexKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.Replicate && key.ReplicateKeyConfig != nil { - // For Replicate, check if deployment exists for this model - if len(key.ReplicateKeyConfig.Deployments) > 0 { - _, deploymentSupported = key.ReplicateKeyConfig.Deployments[model] - } - } else if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { - // For VLLM, check if model name matches the key's configured model + hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) + // ["*"] = allow all models; [] = deny all; specific list = allow only listed + // NOTE: Model filtering uses the original requested model (which may be an alias). + // key.Models and key.BlacklistedModels must therefore be expressed in alias keys. + // The provider-specific identifier is resolved later in the handler closure via key.Aliases.Resolve(model). 
+ modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) + if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { if key.VLLMKeyConfig.ModelName != "" { - deploymentSupported = (key.VLLMKeyConfig.ModelName == model) + modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model) } } - - if modelSupported && deploymentSupported { + if modelSupported { supportedKeys = append(supportedKeys, key) } } } if len(supportedKeys) == 0 { - if baseProviderType == schemas.Azure || baseProviderType == schemas.Bedrock || baseProviderType == schemas.Vertex || baseProviderType == schemas.Replicate || baseProviderType == schemas.VLLM { - return schemas.Key{}, fmt.Errorf("no keys found that support model/deployment: %s", model) - } - return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) + return nil, false, fmt.Errorf("no keys found that support model: %s", model) } - // Key ID takes priority over key name when both are present + // Explicit key ID takes priority over key name — pin to that key, no rotation. 
if ctx != nil { if keyID, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyID).(string); ok { if keyID = strings.TrimSpace(keyID); keyID != "" { for _, key := range supportedKeys { if key.ID == keyID { - return key, nil + return []schemas.Key{key}, false, nil } } - return schemas.Key{}, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model) + return nil, false, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model) } } if keyName, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyName).(string); ok { if keyName = strings.TrimSpace(keyName); keyName != "" { for _, key := range supportedKeys { if key.Name == keyName { - return key, nil + return []schemas.Key{key}, false, nil } } - return schemas.Key{}, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model) + return nil, false, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model) } } } + // Single key: no rotation possible, skip session stickiness (no KV write needed). if len(supportedKeys) == 1 { - return supportedKeys[0], nil + return []schemas.Key{supportedKeys[0]}, false, nil } - // Session stickiness: on the first request for a session ID, the randomly - // selected key is persisted in the KV store. Subsequent requests reuse it as - // long as the key remains valid. The sticky-key lookup/selection in this block - // occurs before executeRequestWithRetries, so the same sticky key is - // intentionally applied for the entire session including all retry attempts— - // the selected key is persisted in KV and reused across retries rather than - // re-selected on each attempt. + // Session stickiness: on the first request for a session ID, the randomly selected key is + // persisted in the KV store. Subsequent requests reuse it for the session lifetime. 
The sticky + // key is intentionally kept fixed across all retry attempts — return it as a single-element + // pool with canRotate=false so rate-limit retries also stay on the same key. sessionID := "" if ctx != nil { if id, ok := ctx.Value(schemas.BifrostContextKeySessionID).(string); ok && id != "" { sessionID = id } } - fallbackIndex := 0 if ctx != nil { fallbackIndex, _ = ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int) @@ -6367,58 +7031,46 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex ttl = schemas.DefaultSessionStickyTTL } - // Try to retrieve existing cached key if cachedKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { - // Refresh TTL so active sessions do not expire. - err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl) - if err != nil { + if err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl); err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, cachedKey.ID, err.Error()) } - return cachedKey, nil + return []schemas.Key{cachedKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } } - // No cached key found (or stale entry deleted), select a new one selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model) if err != nil { - return schemas.Key{}, err + return nil, false, err } - // Atomically set the key only if not already set (first-write-wins) wasSet, err := bifrost.kvStore.SetNXWithTTL(kvKey, selectedKey.ID, ttl) if err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, selectedKey.ID, err.Error()) - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - if wasSet { - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - // Another concurrent 
request won the race, re-read the current key + // Another concurrent request won the race — re-read the persisted key. if currentKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { - return currentKey, nil + return []schemas.Key{currentKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - // Fallback: if we can't read the current key, use what we selected - // (shouldn't happen in normal operation, but defensive) - return selectedKey, nil + return []schemas.Key{selectedKey}, false, nil } - selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model) - if err != nil { - return schemas.Key{}, err - } - - return selectedKey, nil + // Normal case: return the full filtered pool with rotation enabled. + return supportedKeys, true, nil } // getCachedKeyFromStore retrieves a key ID from the KV store and looks it up in supportedKeys. 
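The first-write-wins session-stickiness flow in the hunk above (select a key, attempt an atomic set-if-absent, and on a lost race adopt the concurrently persisted winner) can be sketched with a toy in-memory store. Note this is a minimal illustration under stated assumptions: `miniKV` and `stickyKey` are hypothetical stand-ins, not Bifrost's actual `KVStore` interface, and TTL handling and staleness checks are omitted.

```go
package main

import "fmt"

// miniKV is a hypothetical in-memory stand-in for the KV store (no TTLs).
type miniKV struct{ m map[string]string }

// setNX stores val only if key is absent; returns true if this caller won.
func (kv *miniKV) setNX(key, val string) bool {
	if _, ok := kv.m[key]; ok {
		return false
	}
	kv.m[key] = val
	return true
}

func (kv *miniKV) get(key string) (string, bool) {
	v, ok := kv.m[key]
	return v, ok
}

// stickyKey implements first-write-wins: try to persist our selection;
// if another request already wrote one, re-read and adopt the winner.
func stickyKey(kv *miniKV, sessionKey, selected string) string {
	if kv.setNX(sessionKey, selected) {
		return selected // we won the race: our selection is now sticky
	}
	if winner, ok := kv.get(sessionKey); ok {
		return winner // lost the race: reuse the concurrently persisted key
	}
	return selected // defensive fallback: entry vanished between calls
}

func main() {
	kv := &miniKV{m: map[string]string{}}
	fmt.Println(stickyKey(kv, "sess-1", "key-a")) // first writer wins: key-a
	fmt.Println(stickyKey(kv, "sess-1", "key-b")) // later caller adopts key-a
}
```

The same shape explains why the real code returns a single-element pool with `canRotate=false` after the race: whichever key ends up persisted must be reused verbatim for the session, including across retries.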
@@ -6455,34 +7107,6 @@ func getCachedKeyFromStore(kvStore schemas.KVStore, kvKey string, supportedKeys return schemas.Key{}, false, false } -func WeightedRandomKeySelector(ctx *schemas.BifrostContext, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { - // Use a weighted random selection based on key weights - totalWeight := 0 - for _, key := range keys { - totalWeight += int(key.Weight * 100) // Convert float to int for better performance - } - - // If all keys have zero weight, fall back to uniform random selection - if totalWeight == 0 { - return keys[rand.Intn(len(keys))], nil - } - - // Use global thread-safe random (Go 1.20+) - no allocation, no syscall - randomValue := rand.Intn(totalWeight) - - // Select key based on weight - currentWeight := 0 - for _, key := range keys { - currentWeight += int(key.Weight * 100) - if randomValue < currentWeight { - return key, nil - } - } - - // Fallback to first key if something goes wrong - return keys[0], nil -} - // Shutdown gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. func (bifrost *Bifrost) Shutdown() { @@ -6491,15 +7115,12 @@ func (bifrost *Bifrost) Shutdown() { if bifrost.ctx.Err() == nil && bifrost.cancel != nil { bifrost.cancel() } - // ALWAYS close all provider queues to signal workers to stop, - // even if context was already cancelled. This prevents goroutine leaks. - // Use the ProviderQueue lifecycle: signal closing, then close the queue + // Signal all provider queues to close. Workers exit via pq.done; + // we never close pq.queue to avoid "send on closed channel" panics in + // producers that are concurrently in tryRequest. 
bifrost.requestQueues.Range(func(key, value interface{}) bool { pq := value.(*ProviderQueue) - // Signal closing to producers (uses sync.Once internally) pq.signalClosing() - // Close the queue to signal workers (uses sync.Once internally) - pq.closeQueue() return true }) @@ -6510,6 +7131,12 @@ func (bifrost *Bifrost) Shutdown() { return true }) + // Final drain sweep — same reasoning as RemoveProvider's Step 3b. + bifrost.requestQueues.Range(func(key, value interface{}) bool { + bifrost.drainQueueWithErrors(value.(*ProviderQueue)) + return true + }) + // Cleanup MCP manager if bifrost.MCPManager != nil { err := bifrost.MCPManager.Cleanup() diff --git a/core/bifrost_test.go b/core/bifrost_test.go index cb22f5e359..08490e7dd0 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -3,8 +3,10 @@ package bifrost import ( "context" "fmt" + "runtime" "strings" "sync" + "sync/atomic" "testing" "time" @@ -57,7 +59,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { // Test immediate success t.Run("ImmediateSuccess", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "success", nil } @@ -66,6 +68,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -87,7 +90,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { // Test success after retries t.Run("SuccessAfterRetries", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ if callCount <= 2 { // First two calls fail with retryable error @@ -101,6 +104,7 @@ func TestExecuteRequestWithRetries_SuccessScenarios(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -128,7 +132,7 @@ func 
TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { logger := NewDefaultLogger(schemas.LogLevelError) t.Run("ExceedsMaxRetries", func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ // Always fail with retryable error return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) @@ -138,6 +142,7 @@ func TestExecuteRequestWithRetries_RetryLimits(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -194,7 +199,7 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "", tc.error } @@ -203,6 +208,7 @@ func TestExecuteRequestWithRetries_NonRetryableErrors(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -270,7 +276,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ return "", tc.error } @@ -279,6 +285,7 @@ func TestExecuteRequestWithRetries_RetryableConditions(t *testing.T) { ctx, config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -509,7 +516,7 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { var attemptCounts []int callCount := 0 - handler := func() (string, *schemas.BifrostError) { + handler := func(_ schemas.Key) (string, *schemas.BifrostError) { callCount++ attemptCounts = append(attemptCounts, callCount) @@ -526,6 +533,7 @@ func TestExecuteRequestWithRetries_LoggingAndCounting(t *testing.T) { ctx, 
config, handler, + nil, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", @@ -605,8 +613,8 @@ func TestHandleProviderRequest_OCROperationNotAllowed(t *testing.T) { if err.ExtraFields.RequestType != schemas.OCRRequest { t.Fatalf("expected OCR request type, got %q", err.ExtraFields.RequestType) } - if err.ExtraFields.ModelRequested != "custom-mistral/mistral-ocr-latest" { - t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.ModelRequested) + if err.ExtraFields.OriginalModelRequested != "custom-mistral/mistral-ocr-latest" { + t.Fatalf("expected model to be preserved, got %q", err.ExtraFields.OriginalModelRequested) } } @@ -811,15 +819,15 @@ func (m *mockKVStore) Delete(key string) (bool, error) { return false, nil } -// Test selectKeyFromProviderForModel with session stickiness +// Test selectKeyFromProviderForModelWithPool with session stickiness func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { kvStore := newMockKVStore() account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) // Use 2 keys so we hit the keySelector path (single key returns early) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -842,13 +850,16 @@ func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-123") - // First call: cache miss, keySelector runs, key stored - key1, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", 
schemas.OpenAI) + // First call: cache miss, keySelector runs, key stored; returns single-element pool (canRotate=false) + keys1, canRotate1, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("first selectKeyFromProviderForModel: %v", err) + t.Fatalf("first selectKeyFromProviderForModelWithPool: %v", err) } - if key1.ID != "key-a" { - t.Errorf("first call: expected key-a, got %s", key1.ID) + if canRotate1 { + t.Error("first call: canRotate should be false for session-sticky request") + } + if len(keys1) != 1 || keys1[0].ID != "key-a" { + t.Errorf("first call: expected [key-a], got %v", keys1) } if keySelectorCalls != 1 { t.Errorf("first call: expected 1 keySelector call, got %d", keySelectorCalls) @@ -861,26 +872,29 @@ func TestSelectKeyFromProviderForModel_SessionStickiness(t *testing.T) { } // Second call: cache hit, same key returned, keySelector NOT called - key2, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + keys2, canRotate2, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("second selectKeyFromProviderForModel: %v", err) + t.Fatalf("second selectKeyFromProviderForModelWithPool: %v", err) + } + if canRotate2 { + t.Error("second call: canRotate should be false for session-sticky request") } - if key2.ID != "key-a" { - t.Errorf("second call: expected key-a (sticky), got %s", key2.ID) + if len(keys2) != 1 || keys2[0].ID != "key-a" { + t.Errorf("second call: expected [key-a] (sticky), got %v", keys2) } if keySelectorCalls != 1 { t.Errorf("second call: keySelector should not run (cache hit), got %d calls", keySelectorCalls) } } -// Test selectKeyFromProviderForModel - no stickiness when session ID absent +// Test selectKeyFromProviderForModelWithPool - no stickiness when 
session ID absent func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T) { kvStore := newMockKVStore() account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ - {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Weight: 1}, - {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Weight: 1}, + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, }) var keySelectorCalls int @@ -901,19 +915,22 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T } bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) - // No session ID set + // No session ID set — pool is returned with canRotate=true; keySelector is called each time. for i := 0; i < 2; i++ { - key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err != nil { - t.Fatalf("selectKeyFromProviderForModel call %d: %v", i+1, err) + t.Fatalf("selectKeyFromProviderForModelWithPool call %d: %v", i+1, err) + } + if !canRotate { + t.Fatalf("call %d: canRotate should be true without a session id", i+1) } - if key.ID != "key-a" { - t.Fatalf("call %d: expected key-a, got %s", i+1, key.ID) + if len(pool) == 0 { + t.Fatalf("call %d: expected non-empty pool", i+1) } } - if keySelectorCalls != 2 { - t.Errorf("expected 2 keySelector calls without a session id, got %d", keySelectorCalls) + if keySelectorCalls != 0 { + t.Errorf("expected 0 keySelector calls from pool building (no session id), got %d", keySelectorCalls) } // KVStore should not have a sticky entry for an empty 
session id if _, err := kvStore.Get(buildSessionKey(schemas.OpenAI, "", "gpt-4")); err == nil { @@ -921,6 +938,82 @@ func TestSelectKeyFromProviderForModel_NoStickinessWithoutSessionID(t *testing.T } } +// TestSelectKeyFromProviderForModel_SessionStickinessNoRotation verifies that when a session ID +// is present, rate-limit retries reuse the sticky key rather than rotating to another key. +func TestSelectKeyFromProviderForModel_SessionStickinessNoRotation(t *testing.T) { + kvStore := newMockKVStore() + account := NewMockAccount() + account.AddProvider(schemas.OpenAI, 5, 1000) + account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ + {ID: "key-a", Name: "Key A", Value: *schemas.NewEnvVar("sk-a"), Models: schemas.WhiteList{"*"}, Weight: 1}, + {ID: "key-b", Name: "Key B", Value: *schemas.NewEnvVar("sk-b"), Models: schemas.WhiteList{"*"}, Weight: 1}, + }) + + deterministicSelector := func(ctx *schemas.BifrostContext, keys []schemas.Key, _ schemas.ModelProvider, _ string) (schemas.Key, error) { + return keys[0], nil // always picks key-a when pool includes it + } + + ctx := context.Background() + bifrost, err := Init(ctx, schemas.BifrostConfig{ + Account: account, + Logger: NewDefaultLogger(schemas.LogLevelError), + KVStore: kvStore, + KeySelector: deterministicSelector, + }) + if err != nil { + t.Fatalf("Init failed: %v", err) + } + + bfCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bfCtx.SetValue(schemas.BifrostContextKeySessionID, "sess-sticky") + bfCtx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) + + config := createTestConfig(3, 0, 0) + logger := NewDefaultLogger(schemas.LogLevelError) + + // Build keyProvider the same way requestWorker does. 
+ pool, canRotate, poolErr := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if poolErr != nil { + t.Fatalf("pool build failed: %v", poolErr) + } + if canRotate { + t.Fatal("expected canRotate=false for session-sticky request") + } + if len(pool) != 1 || pool[0].ID != "key-a" { + t.Fatalf("expected sticky pool=[key-a], got %v", pool) + } + + fixedKey := pool[0] + keyProvider := func(_ map[string]bool) (schemas.Key, error) { return fixedKey, nil } + + // Simulate 3 rate-limit failures then success; all attempts must use key-a. + var usedKeyIDs []string + callCount := 0 + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + usedKeyIDs = append(usedKeyIDs, k.ID) + callCount++ + if callCount <= 3 { + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + return "ok", nil + } + + result, retryErr := executeRequestWithRetries(bfCtx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if retryErr != nil { + t.Fatalf("expected success, got error: %v", retryErr) + } + if result != "ok" { + t.Errorf("expected 'ok', got %s", result) + } + for i, id := range usedKeyIDs { + if id != "key-a" { + t.Errorf("attempt %d: expected sticky key-a, got %s (full sequence: %v)", i, id, usedKeyIDs) + } + } +} + func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { account := NewMockAccount() account.AddProvider(schemas.OpenAI, 5, 1000) @@ -939,7 +1032,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, }) - _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + _, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, 
schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err == nil { t.Fatal("expected error when model is only blacklisted") } @@ -956,7 +1049,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { BlacklistedModels: []string{"gpt-4"}, }, }) - _, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + _, _, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) if err == nil { t.Fatal("expected error when model is both allowed and blacklisted") } @@ -965,14 +1058,200 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) { t.Run("second key used when first blacklists", func(t *testing.T) { account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{ {ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}}, - {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1}, + {ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}}, }) - key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + pool, canRotate, err := bifrost.selectKeyFromProviderForModelWithPool(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // After filtering, only k2 remains — single key returns canRotate=false. 
+ if canRotate { + t.Fatal("expected canRotate=false for single-key pool after filtering") + } + if len(pool) != 1 || pool[0].ID != "k2" { + t.Fatalf("expected pool=[k2], got %v", pool) + } + }) +} + +// Test key rotation in executeRequestWithRetries on rate-limit errors +func TestExecuteRequestWithRetries_KeyRotation(t *testing.T) { + config := createTestConfig(3, 0, 0) + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) + logger := NewDefaultLogger(schemas.LogLevelError) + + keys := []schemas.Key{ + {ID: "k1", Name: "K1"}, + {ID: "k2", Name: "K2"}, + {ID: "k3", Name: "K3"}, + } + + t.Run("RotatesKeyOnRateLimitRetry", func(t *testing.T) { + var selectedKeyIDs []string + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + for _, k := range keys { + if !usedKeyIDs[k.ID] { + return k, nil + } + } + // Fresh round + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + return keys[0], nil + } + + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + // First two calls rate-limit, third succeeds + if len(selectedKeyIDs) <= 2 { + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + return "success", nil + } + + result, err := executeRequestWithRetries(ctx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if result != "success" { + t.Errorf("expected 'success', got %s", result) + } + if len(selectedKeyIDs) != 3 { + t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs)) + } + // Each attempt should use a different key + seen := map[string]struct{}{} + for _, id := range selectedKeyIDs { + seen[id] = struct{}{} + } + if len(seen) != len(selectedKeyIDs) { + t.Errorf("expected distinct keys per rate-limit retry, got %v", 
selectedKeyIDs) + } + }) + + t.Run("SameKeyOnNetworkError", func(t *testing.T) { + var selectedKeyIDs []string + keyProviderCalls := 0 + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + keyProviderCalls++ + for _, k := range keys { + if !usedKeyIDs[k.ID] { + return k, nil + } + } + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + return keys[0], nil + } + + callCount := 0 + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + callCount++ + if callCount <= 2 { + return "", createBifrostError(schemas.ErrProviderDoRequest, nil, nil, false) + } + return "success", nil + } + + result, err := executeRequestWithRetries(ctx, config, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if err != nil { + t.Fatalf("expected success, got error: %v", err) + } + if result != "success" { + t.Errorf("expected 'success', got %s", result) + } + if len(selectedKeyIDs) != 3 { + t.Fatalf("expected 3 attempts, got %d", len(selectedKeyIDs)) + } + if keyProviderCalls != 1 { + t.Fatalf("expected keyProvider to be called once for network retries, got %d", keyProviderCalls) + } + // All attempts should use the same key (network error = same key) + for i := 1; i < len(selectedKeyIDs); i++ { + if selectedKeyIDs[i] != selectedKeyIDs[0] { + t.Errorf("expected same key for all network-error retries, got %v", selectedKeyIDs) + } + } + }) + + t.Run("CyclesFreshRoundWhenPoolExhausted", func(t *testing.T) { + var selectedKeyIDs []string + // 3 keys, 6 retries — should cycle through all 3 keys twice + config6 := createTestConfig(5, 0, 0) // 5 retries = 6 total attempts + keyProvider := func(usedKeyIDs map[string]bool) (schemas.Key, error) { + available := make([]schemas.Key, 0) + for _, k := range keys { + if !usedKeyIDs[k.ID] { + available = append(available, k) + } + } + if len(available) == 0 { + for id := range usedKeyIDs { + delete(usedKeyIDs, id) + } + 
available = keys + } + return available[0], nil + } + + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + selectedKeyIDs = append(selectedKeyIDs, k.ID) + return "", createBifrostError("rate limit exceeded", Ptr(429), nil, false) + } + + executeRequestWithRetries(ctx, config6, handler, keyProvider, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + + if len(selectedKeyIDs) != 6 { + t.Fatalf("expected 6 attempts (1 initial + 5 retries), got %d", len(selectedKeyIDs)) + } + // First cycle: k1, k2, k3; second cycle: k1, k2, k3 + expected := []string{"k1", "k2", "k3", "k1", "k2", "k3"} + for i, id := range selectedKeyIDs { + if id != expected[i] { + t.Errorf("attempt %d: expected key %s, got %s (full sequence: %v)", i, expected[i], id, selectedKeyIDs) + } + } + }) + + t.Run("NilKeyProviderUsesZeroKey", func(t *testing.T) { + cleanCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + cleanCtx.SetValue(schemas.BifrostContextKeyTracer, &schemas.NoOpTracer{}) + + var receivedKey schemas.Key + handler := func(k schemas.Key) (string, *schemas.BifrostError) { + receivedKey = k + return "ok", nil + } + + result, err := executeRequestWithRetries(cleanCtx, config, handler, nil, + schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", nil, logger) + if err != nil { t.Fatalf("unexpected error: %v", err) } - if key.ID != "k2" { - t.Fatalf("expected k2, got %s", key.ID) + if result != "ok" { + t.Errorf("expected 'ok', got %s", result) + } + if receivedKey.ID != "" { + t.Errorf("expected zero Key when keyProvider is nil, got ID=%s", receivedKey.ID) + } + if trail, ok := cleanCtx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 { + t.Fatalf("expected no attempt trail for nil keyProvider, got %v", trail) + } + if selectedID, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyID).(string); selectedID != "" { + t.Fatalf("expected empty selected key id, got %q", selectedID) + } 
+ if selectedName, _ := cleanCtx.Value(schemas.BifrostContextKeySelectedKeyName).(string); selectedName != "" { + t.Fatalf("expected empty selected key name, got %q", selectedName) } }) } @@ -1300,3 +1579,998 @@ func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) { } }) } + +// TestProviderQueue_SendOnClosedChannel_Race demonstrates the TOCTOU race that +// caused the "send on closed channel" production panic in the OLD code. +// +// The old code called close(pq.queue) during provider shutdown. The sequence: +// 1. Producer calls isClosing() → false (queue is still open) +// 2. Concurrently: shutdown calls signalClosing() then close(pq.queue) +// 3. Producer enters select { case pq.queue <- msg: ... case <-pq.done: ... } +// → PANIC: Go's selectgo iterates cases in a randomised pollorder. When the +// closed-channel send case is checked first, it immediately panics via +// goto sclose — before it can reach the done case. +// The case <-pq.done: guard only saves you when done happens to be checked +// first in that random ordering (≈50 % of the time with two cases). +// +// THE FIX: pq.queue is never closed. See the ProviderQueue struct comment for +// the full explanation. This test is kept as a proof-of-concept showing why +// closing pq.queue is unsafe; the fix is validated by TestProviderQueue_NoPanicWithoutCloseQueue. +// +// We run many iterations so that the panic is statistically certain to surface +// at least once, confirming the hypothesis. +func TestProviderQueue_SendOnClosedChannel_Race(t *testing.T) { + // With two select cases each iteration has a ~50 % chance of panicking. + // The probability of never panicking in 200 iterations is (0.5)^200 ≈ 0. + const iterations = 200 + panicCount := 0 + + for i := 0; i < iterations; i++ { + func() { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Synchronization barriers to force the exact race interleaving. 
+ passedIsClosingCheck := make(chan struct{}) + queueClosed := make(chan struct{}) + + var panicked bool + var wg sync.WaitGroup + wg.Add(1) + + // Producer — mirrors the hot path in tryRequest. + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil && fmt.Sprint(r) == "send on closed channel" { + panicked = true + } + }() + + // Step 1: isClosing() passes — queue is open. + if pq.isClosing() { + return + } + + // Signal: past the isClosing() gate. + close(passedIsClosingCheck) + + // Wait for the queue to be closed. This represents the real work + // tryRequest does between the isClosing() check and the select + // (MCP setup, tracer lookup, plugin pipeline acquisition). + <-queueClosed + + // Step 2: enter the exact select guard used in production. + // pq.queue is closed AND pq.done is closed. + // When selectgo picks the send case first in its random pollorder + // it hits goto sclose and panics — the done case cannot save it. + msg := &ChannelMessage{} + select { + case pq.queue <- msg: // panics ~50 % of iterations + case <-pq.done: // selected the other ~50 % + } + }() + + // Closer — mirrors UpdateProvider / RemoveProvider. + go func() { + <-passedIsClosingCheck + pq.signalClosing() // closes done, sets closing = 1 + close(pq.queue) + close(queueClosed) // release the producer into the select + }() + + wg.Wait() + if panicked { + panicCount++ + } + }() + } + + if panicCount == 0 { + t.Fatalf("expected at least one 'send on closed channel' panic across %d iterations, got none", iterations) + } + t.Logf("confirmed: panic triggered in %d / %d iterations — hypothesis is correct", panicCount, iterations) +} + +// ============================================================================= +// ProviderQueue Unit Tests +// +// These tests exercise the ProviderQueue lifecycle in isolation — no full +// Bifrost instance required. They validate the core safety invariants that +// prevent the "send on closed channel" panic. 
+// ============================================================================= + +// newTestChannelMessage creates a minimal ChannelMessage suitable for drain tests. +// The Err channel is buffered (size 1) so the worker can send without blocking. +func newTestChannelMessage(ctx *schemas.BifrostContext) *ChannelMessage { + return &ChannelMessage{ + BifrostRequest: schemas.BifrostRequest{ + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + }, + }, + Context: ctx, + Response: make(chan *schemas.BifrostResponse, 1), + Err: make(chan schemas.BifrostError, 1), + } +} + +// TestProviderQueue_IsClosingStateTransition verifies the atomic state flag: +// isClosing() must return false before signalClosing() and true after. +func TestProviderQueue_IsClosingStateTransition(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + if pq.isClosing() { + t.Fatal("isClosing() must be false before signalClosing() is called") + } + + pq.signalClosing() + + if !pq.isClosing() { + t.Fatal("isClosing() must be true after signalClosing() is called") + } + + // done channel must also be closed + select { + case <-pq.done: + // correct: done is closed + default: + t.Fatal("pq.done must be closed after signalClosing()") + } + + // queue channel must remain OPEN — this is the core of the fix + // (sending should not panic even though done is closed) + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + select { + case pq.queue <- &ChannelMessage{}: + case <-pq.done: // done is closed so this is always ready — no panic + } + }() + if panicked { + t.Fatal("queue channel must stay open after signalClosing() — sending to it must not panic") + } +} + +// TestProviderQueue_SignalOnceIdempotent verifies that calling signalClosing() +// multiple times is safe. 
sync.Once ensures done is only closed once and the +// atomic store only happens once — no "close of closed channel" panic. +func TestProviderQueue_SignalOnceIdempotent(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic from multiple signalClosing() calls: %v", r) + } + }() + + pq.signalClosing() + pq.signalClosing() + pq.signalClosing() + + if !pq.isClosing() { + t.Fatal("isClosing() must be true after multiple signalClosing() calls") + } +} + +// TestProviderQueue_WorkerExitsViaDone verifies that a worker running the +// fixed select loop exits cleanly after signalClosing() without closeQueue(). +// Before the fix, workers used `for req := range pq.queue` which required +// the channel to be closed. After the fix, done is the exit signal. +func TestProviderQueue_WorkerExitsViaDone(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + workerExited := make(chan struct{}) + + // Minimal worker loop — mirrors the exact select pattern in requestWorker + go func() { + defer close(workerExited) + for { + select { + case r, ok := <-pq.queue: + if !ok { + return + } + _ = r // process (no-op in this test) + case <-pq.done: + // Drain remaining buffered items (queue is empty here) + for { + select { + case <-pq.queue: + default: + return + } + } + } + } + }() + + // Worker is now blocked on the select. Signal shutdown WITHOUT closing queue. 
+ pq.signalClosing() + + select { + case <-workerExited: + // correct: worker exited via done + case <-time.After(2 * time.Second): + t.Fatal("worker did not exit after signalClosing() — it may be stuck on range over unclosed channel") + } +} + +// TestProviderQueue_WorkerDrainSendsErrors verifies the drain behaviour when +// done fires while items are still buffered: every buffered ChannelMessage must +// receive a "provider is shutting down" error on its Err channel. No client +// should be left blocked waiting for a response that will never come. +// +// This test exercises the drain path directly — same code as requestWorker's +// case <-pq.done: branch — to avoid a non-deterministic select race between the +// normal processing path and the done path. +func TestProviderQueue_WorkerDrainSendsErrors(t *testing.T) { + const numBuffered = 5 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numBuffered+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Pre-fill queue — simulates requests buffered when done fires + msgs := make([]*ChannelMessage, numBuffered) + for i := 0; i < numBuffered; i++ { + msgs[i] = newTestChannelMessage(ctx) + pq.queue <- msgs[i] + } + + // Signal closing: done is now closed + pq.signalClosing() + + // Execute the drain path synchronously — exactly what requestWorker does in + // the case <-pq.done: branch. This is deterministic: we know done is closed + // and the queue has numBuffered items. 
+ <-pq.done // fires immediately since signalClosing was already called +drainLoop: + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "provider is shutting down", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + OriginalModelRequested: mod, + }, + } + default: + break drainLoop + } + } + + // Verify every message received a shutdown error + for i, msg := range msgs { + select { + case bifrostErr := <-msg.Err: + if bifrostErr.Error == nil { + t.Errorf("message %d: received nil Error field", i) + continue + } + if bifrostErr.Error.Message != "provider is shutting down" { + t.Errorf("message %d: expected 'provider is shutting down', got %q", + i, bifrostErr.Error.Message) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Errorf("message %d: expected provider %s, got %s", + i, schemas.OpenAI, bifrostErr.ExtraFields.Provider) + } + if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest { + t.Errorf("message %d: expected requestType %v, got %v", + i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType) + } + default: + t.Errorf("message %d: no error received — client would be left hanging indefinitely", i) + } + } +} + +// TestProviderQueue_NoPanicWithoutCloseQueue verifies that the fixed hot path +// — select { case pq.queue <- msg | case <-pq.done } — never panics when +// signalClosing() fires but the queue channel is NOT closed. +// +// This is the direct inverse of TestProviderQueue_SendOnClosedChannel_Race: +// that test proves the old code panics ~50% of the time; this test proves +// the fixed code panics 0% of the time. 
+func TestProviderQueue_NoPanicWithoutCloseQueue(t *testing.T) { + const iterations = 500 + + for i := 0; i < iterations; i++ { + func() { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + passedIsClosingCheck := make(chan struct{}) + shutdownDone := make(chan struct{}) + + var panicked bool + var wg sync.WaitGroup + wg.Add(1) + + // Producer: mirrors the tryRequest hot path after the fix. + // Passes isClosing(), waits for signalClosing, then sends. + // The queue channel is NEVER closed — only done is closed. + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if pq.isClosing() { + return + } + close(passedIsClosingCheck) + <-shutdownDone + + msg := &ChannelMessage{} + select { + case pq.queue <- msg: // queue is open → safe to send + case <-pq.done: // done is closed → selected immediately + } + }() + + // Closer: signal shutdown but never close the queue channel + go func() { + <-passedIsClosingCheck + pq.signalClosing() // closes done; does NOT close queue + close(shutdownDone) + }() + + wg.Wait() + + if panicked { + t.Errorf("iteration %d: unexpected panic — queue must not be closed in the fixed path", i) + } + }() + + if t.Failed() { + return + } + } + + t.Logf("confirmed: zero panics in %d iterations with the fix applied", iterations) +} + +// ============================================================================= +// UpdateProvider Lifecycle Tests +// +// These tests verify the three key invariants of the UpdateProvider fix: +// 1. New queue is stored BEFORE signalClosing fires (stale producers re-route) +// 2. Transfer happens BEFORE signalClosing (items go to new workers, not errored) +// 3. 
Concurrent producers + UpdateProvider produce zero panics +// ============================================================================= + +// TestUpdateProvider_StaleProducerReroutes verifies that a "stale producer" — +// a goroutine that fetched oldPq before UpdateProvider atomically replaced it — +// can transparently re-route to newPq when it later detects isClosing(). +// +// The re-routing logic in tryRequest is: +// +// if pq.isClosing() { +// if newPq, err := bifrost.getProviderQueue(provider); err == nil && newPq != pq { +// pq = newPq // transparent re-route +// } +// } +// +// This test exercises that exact sequence without a full Bifrost instance. +func TestUpdateProvider_StaleProducerReroutes(t *testing.T) { + var requestQueues sync.Map + provider := schemas.OpenAI + + oldPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + newPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Initial state: requestQueues holds oldPq + requestQueues.Store(provider, oldPq) + + // Stale producer: fetched its reference before UpdateProvider ran + stalePq := oldPq + + // Simulate UpdateProvider steps 2 + 4: + // Step 2: atomically replace — new producers now get newPq + requestQueues.Store(provider, newPq) + // Step 4: signal old closing — stale producers will detect this + oldPq.signalClosing() + + // --- Stale producer detects isClosing and attempts re-route --- + var reroutedPq *ProviderQueue + if stalePq.isClosing() { + if val, ok := requestQueues.Load(provider); ok { + candidate := val.(*ProviderQueue) + if candidate != stalePq { + reroutedPq = candidate + } + } + } + + if reroutedPq == nil { + t.Fatal("stale producer failed to re-route: re-route returned nil (check step ordering)") + } + if reroutedPq != newPq { + t.Fatal("stale producer re-routed to wrong queue: expected newPq") + } + if reroutedPq.isClosing() { + 
t.Fatal("re-routed queue is already closing — re-route is useless (newPq must be fresh)") + } + + // Verify: sending to re-routed queue succeeds without panic + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + msg := &ChannelMessage{} + select { + case reroutedPq.queue <- msg: + case <-reroutedPq.done: + t.Error("newPq.done fired — newPq should be open") + } + }() + if panicked { + t.Fatal("panic while sending to re-routed queue — queue must not be closed") + } +} + +// TestUpdateProvider_TransferOrdering verifies the ordering invariant: +// items are moved from oldPq to newPq BEFORE signalClosing(oldPq) is called. +// +// Observable consequence: during the entire transfer loop, oldPq.isClosing() +// must remain false. Only after transfer completes does signalClosing fire. +func TestUpdateProvider_TransferOrdering(t *testing.T) { + const numMessages = 8 + + oldPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numMessages+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + newPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numMessages+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Pre-fill oldPq — simulates buffered requests at the moment UpdateProvider runs + for i := 0; i < numMessages; i++ { + oldPq.queue <- &ChannelMessage{} + } + + // Invariant check before transfer begins + if oldPq.isClosing() { + t.Fatal("invariant violated: oldPq already closing before transfer begins") + } + + // Perform transfer, mirroring UpdateProvider step 3. + // Record whether isClosing() ever fired during the loop. 
+ closingDuringTransfer := false + transferred := 0 + for { + select { + case msg := <-oldPq.queue: + if oldPq.isClosing() { + closingDuringTransfer = true + } + newPq.queue <- msg + transferred++ + default: + goto transferComplete + } + } +transferComplete: + + if closingDuringTransfer { + t.Error("invariant violated: oldPq was already closing during transfer — " + + "signalClosing must fire AFTER the transfer loop completes") + } + + // NOW signal closing, mirroring UpdateProvider step 4 + oldPq.signalClosing() + + if !oldPq.isClosing() { + t.Error("expected isClosing() == true after signalClosing()") + } + + // All messages must have moved to newPq + if transferred != numMessages { + t.Errorf("expected %d messages transferred, got %d", numMessages, transferred) + } + if len(newPq.queue) != numMessages { + t.Errorf("expected %d messages in newPq after transfer, got %d", numMessages, len(newPq.queue)) + } + if len(oldPq.queue) != 0 { + t.Errorf("expected 0 messages remaining in oldPq after transfer, got %d", len(oldPq.queue)) + } +} + +// TestUpdateProvider_NoPanicConcurrentAccess verifies that concurrent producers +// sending to a queue that is being replaced (UpdateProvider-style) never cause +// a "send on closed channel" panic. +// +// This test directly models the production scenario that triggered the bug: +// many goroutines continuously send to a ProviderQueue while UpdateProvider +// atomically swaps the queue and signals the old one closing. With the fix +// (queue channel is never closed), the select in producers is always safe. 
+func TestUpdateProvider_NoPanicConcurrentAccess(t *testing.T) { + const ( + numProducers = 10 + numUpdates = 30 + producerRunTime = 300 * time.Millisecond + ) + + var requestQueues sync.Map + provider := schemas.OpenAI + + makePq := func() *ProviderQueue { + return &ProviderQueue{ + queue: make(chan *ChannelMessage, 200), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + } + + initialPq := makePq() + requestQueues.Store(provider, initialPq) + + var panicCount int64 + var transferDropCount int64 + + stop := make(chan struct{}) + var producerWg sync.WaitGroup + + // Drainer: continuously empties queues so producers never block on a full queue + drainStop := make(chan struct{}) + go func() { + for { + select { + case <-drainStop: + return + default: + if val, ok := requestQueues.Load(provider); ok { + pq := val.(*ProviderQueue) + select { + case <-pq.queue: + default: + } + } + runtime.Gosched() + } + } + }() + + // Producers: continuously simulate the tryRequest hot path + for i := 0; i < numProducers; i++ { + producerWg.Add(1) + go func() { + defer producerWg.Done() + for { + select { + case <-stop: + return + default: + } + + val, ok := requestQueues.Load(provider) + if !ok { + runtime.Gosched() + continue + } + pq := val.(*ProviderQueue) + + func() { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + }() + + // Re-route check (mirrors tryRequest) + if pq.isClosing() { + if newVal, ok2 := requestQueues.Load(provider); ok2 { + if candidate := newVal.(*ProviderQueue); candidate != pq { + pq = candidate + } + } + // If still closing (RemoveProvider path), just return + if pq.isClosing() { + return + } + } + + msg := &ChannelMessage{} + select { + case pq.queue <- msg: + case <-pq.done: + case <-stop: // unblock immediately when the test signals stop + } + }() + + runtime.Gosched() + } + }() + } + + // Updater: repeatedly performs UpdateProvider-style queue replacements + var updaterWg sync.WaitGroup + updaterWg.Add(1) 
+ go func() { + defer updaterWg.Done() + for i := 0; i < numUpdates; i++ { + val, ok := requestQueues.Load(provider) + if !ok { + continue + } + oldPq := val.(*ProviderQueue) + newPq := makePq() + + // Mirror production UpdateProvider step order exactly: + // Step 2: expose newPq first so stale producers can re-route to it + // once they see oldPq is closing. + requestQueues.Store(provider, newPq) + + // Step 3: transfer buffered messages oldPq → newPq. + drain: + for { + select { + case msg := <-oldPq.queue: + select { + case newPq.queue <- msg: + default: + // newPq full during transfer — mirrors production cancel path. + atomic.AddInt64(&transferDropCount, 1) + } + default: + break drain + } + } + + // Step 4: signal closing — producers holding a stale oldPq ref now + // re-route to newPq (already in the map from step 2). + oldPq.signalClosing() + + time.Sleep(5 * time.Millisecond) + } + }() + + time.Sleep(producerRunTime) + close(stop) + close(drainStop) + producerWg.Wait() + updaterWg.Wait() + + if n := atomic.LoadInt64(&panicCount); n > 0 { + t.Errorf("detected %d panic(s) — fix did not eliminate the concurrent-access race", n) + } else { + t.Logf("confirmed: zero panics across %d producers + %d queue replacements over %v", + numProducers, numUpdates, producerRunTime) + } + if drops := atomic.LoadInt64(&transferDropCount); drops > 0 { + t.Logf("note: %d message(s) dropped during transfer (oldPq had >200 buffered items) — does not affect panic correctness", drops) + } +} + +// ============================================================================= +// RemoveProvider Lifecycle Tests +// +// These tests verify the behavioral contract of RemoveProvider: +// 1. signalClosing() blocks new producers (isClosing() → true) +// 2. Buffered items in the queue get "provider is shutting down" errors +// 3. 
Workers exit cleanly and the WaitGroup reaches zero +// ============================================================================= + +// TestRemoveProvider_BlocksNewProducers verifies that after signalClosing(), +// isClosing() returns true. Producers check this flag before sending and return +// a "provider is shutting down" error rather than trying to enqueue. +func TestRemoveProvider_BlocksNewProducers(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Sanity: before shutdown, producers can proceed + if pq.isClosing() { + t.Fatal("isClosing() must be false before RemoveProvider runs") + } + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // New producers must see isClosing() == true and abort + if !pq.isClosing() { + t.Fatal("isClosing() must be true after signalClosing() (RemoveProvider)") + } + + // done must be closed so any producer blocked in the select unblocks immediately + select { + case <-pq.done: + // correct + default: + t.Fatal("pq.done must be closed after signalClosing() so blocking producers unblock") + } + + // CRITICAL: queue channel must remain OPEN — closing it would cause panics in + // any producer that entered the select before seeing isClosing(). + // With the fix, we NEVER close the queue channel. + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + // A select with done closed always takes the done case — safe, no panic + select { + case pq.queue <- &ChannelMessage{}: + case <-pq.done: + } + }() + if panicked { + t.Fatal("queue channel must stay open after signalClosing() — closing it causes panics") + } +} + +// TestRemoveProvider_BufferedRequestsGetErrors verifies the drain contract: +// items queued BEFORE signalClosing fires must each receive a +// "provider is shutting down" error on their Err channel. No client should be +// left hanging. 
+// +// This test exercises the drain logic directly — the same code path that +// requestWorker executes in its case <-pq.done: branch — to avoid the +// non-deterministic select race where the normal processing path can pick up +// items before done fires. +func TestRemoveProvider_BufferedRequestsGetErrors(t *testing.T) { + const numBuffered = 8 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numBuffered+5), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Buffer requests — simulates requests already queued when RemoveProvider runs + msgs := make([]*ChannelMessage, numBuffered) + for i := 0; i < numBuffered; i++ { + msgs[i] = newTestChannelMessage(ctx) + pq.queue <- msgs[i] + } + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // Execute the drain path — exactly what requestWorker does in case <-pq.done: + <-pq.done // fires immediately since signalClosing was already called +drainLoop: + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "provider is shutting down", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + OriginalModelRequested: mod, + }, + } + default: + break drainLoop + } + } + + // Every buffered message must have received a shutdown error + for i, msg := range msgs { + select { + case bifrostErr := <-msg.Err: + if bifrostErr.Error == nil { + t.Errorf("message %d: got nil Error field in BifrostError", i) + continue + } + if bifrostErr.Error.Message != "provider is shutting down" { + t.Errorf("message %d: expected 'provider is shutting down', got %q", + i, bifrostErr.Error.Message) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Errorf("message %d: expected provider %s, got %s", + i, schemas.OpenAI, 
bifrostErr.ExtraFields.Provider) + } + if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest { + t.Errorf("message %d: expected requestType %v, got %v", + i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType) + } + default: + t.Errorf("message %d: no error received — client would be left hanging indefinitely", i) + } + } +} + +// TestRemoveProvider_WorkerWaitGroupCompletes verifies that after signalClosing(), +// the worker goroutine decrements the WaitGroup and wg.Wait() returns promptly. +// This mirrors what RemoveProvider does: signal, then Wait() before cleanup. +func TestRemoveProvider_WorkerWaitGroupCompletes(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + var wg sync.WaitGroup + wg.Add(1) + + // Worker goroutine — mirrors requestWorker's WaitGroup contract + go func() { + defer wg.Done() + for { + select { + case r, ok := <-pq.queue: + if !ok { + return + } + _ = r + case <-pq.done: + // Drain remaining (empty in this test) + for { + select { + case <-pq.queue: + default: + return + } + } + } + } + }() + + // Tiny sleep to ensure worker is parked on select before we signal + time.Sleep(10 * time.Millisecond) + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // RemoveProvider step 3: wait for workers — must complete promptly + waitReturned := make(chan struct{}) + go func() { + wg.Wait() + close(waitReturned) + }() + + select { + case <-waitReturned: + // correct: WaitGroup reached zero after signalClosing() + case <-time.After(2 * time.Second): + t.Fatal("wg.Wait() did not return after signalClosing() — worker is stuck (would deadlock RemoveProvider)") + } +} + +// TestRemoveProvider_ConcurrentNewProducersDuringShutdown verifies that +// concurrent producers trying to enqueue after RemoveProvider calls +// signalClosing() all get safe "provider is shutting down" errors — none panic. 
+// This tests the TOCTOU window: producer passes isClosing() check, then done fires. +func TestRemoveProvider_ConcurrentNewProducersDuringShutdown(t *testing.T) { + const numProducers = 50 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numProducers+10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + var panicCount int64 + var shutdownErrors int64 + var successfulSends int64 + + // Gate: all producers start together after isClosing() passes + passedGate := make(chan struct{}) + var gateOnce sync.Once + shutdownFired := make(chan struct{}) + + var producerWg sync.WaitGroup + + for i := 0; i < numProducers; i++ { + producerWg.Add(1) + go func() { + defer producerWg.Done() + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + }() + + // Each producer checks isClosing() first (mirrors tryRequest) + if pq.isClosing() { + atomic.AddInt64(&shutdownErrors, 1) + return + } + + // Signal that at least one producer passed the isClosing() check + gateOnce.Do(func() { close(passedGate) }) + + // Wait for shutdown to be signaled (the TOCTOU window) + <-shutdownFired + + // Producers now enter the select — with the fix, done is closed but + // queue is NOT closed, so this select is always safe (no panic) + msg := &ChannelMessage{} + select { + case pq.queue <- msg: + atomic.AddInt64(&successfulSends, 1) + case <-pq.done: + atomic.AddInt64(&shutdownErrors, 1) + } + }() + } + + // Wait for at least one producer to pass the isClosing() gate + select { + case <-passedGate: + case <-time.After(2 * time.Second): + t.Fatal("no producer passed the isClosing() check within timeout") + } + + // Signal shutdown (RemoveProvider step 2) — this is the TOCTOU race + pq.signalClosing() + close(shutdownFired) + + producerWg.Wait() + + if n := atomic.LoadInt64(&panicCount); n > 0 { + t.Errorf("detected %d panic(s) — queue must not be closed during concurrent shutdown", n) + } + + t.Logf("result: %d successful sends, %d shutdown 
errors, %d panics across %d producers", + atomic.LoadInt64(&successfulSends), + atomic.LoadInt64(&shutdownErrors), + atomic.LoadInt64(&panicCount), + numProducers) +} diff --git a/core/changelog.md b/core/changelog.md index 45b09cb9eb..080817cb7d 100644 --- a/core/changelog.md +++ b/core/changelog.md @@ -1,5 +1,5 @@ -- fix: Gemini provider - handle content block tool outputs in Responses API path -- fix: case-insensitive `anthropic-beta` merge in `MergeBetaHeaders` -- fix: Bedrock provider - emit message_stop event for Anthropic invoke stream [@tefimov](https://github.com/tefimov) -- fix: gemini preserves thinkingLevel parameters during round-trip and finish reason mapping -- fix: WebSearch tool argument handling for all clients by removing the Claude Code user agent restriction +- fix: Gemini provider - handle content block tool outputs in Responses API path +- fix: case-insensitive `anthropic-beta` merge in `MergeBetaHeaders` +- fix: Bedrock provider - emit message_stop event for Anthropic invoke stream [@tefimov](https://github.com/tefimov) +- fix: gemini preserves thinkingLevel parameters during round-trip and finish reason mapping +- fix: WebSearch tool argument handling for all clients by removing the Claude Code user agent restriction diff --git a/core/go.mod b/core/go.mod index 924e204ea2..013296f021 100644 --- a/core/go.mod +++ b/core/go.mod @@ -1,18 +1,18 @@ module github.com/maximhq/bifrost/core -go 1.26.2 +go 1.26.1 require ( cloud.google.com/go v0.123.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 github.com/andybalholm/brotli v1.2.0 - github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2 v1.41.5 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 github.com/aws/aws-sdk-go-v2/config v1.32.11 - github.com/aws/aws-sdk-go-v2/credentials v1.19.11 - github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 - github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 + 
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 + github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 github.com/aws/smithy-go v1.24.2 github.com/bytedance/sonic v1.15.0 github.com/fasthttp/websocket v1.5.12 @@ -35,18 +35,18 @@ require ( cloud.google.com/go/compute/metadata v0.9.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect - github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect + github.com/aws/aws-sdk-go-v2/service/signin 
v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.2 // indirect github.com/bytedance/gopkg v0.1.3 // indirect diff --git a/core/go.sum b/core/go.sum index a01765f093..685035b381 100644 --- a/core/go.sum +++ b/core/go.sum @@ -16,42 +16,38 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgv github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= -github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= -github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= +github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= -github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= -github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod 
h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod 
h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= -github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= -github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= -github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= -github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= 
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.13 h1:JRaIgADQS/U6uXDqlPiefP32yXTda7Kqfx+LgspooZM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 h1:ZlvrNcHSFFWURB8avufQq9gFsheUgjVD9536obIknfM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= diff --git a/core/internal/llmtests/account.go b/core/internal/llmtests/account.go index ac850830fd..e4044e0e30 100644 --- a/core/internal/llmtests/account.go +++ b/core/internal/llmtests/account.go @@ -15,80 +15,89 @@ 
import ( const Concurrency = 4 +// Replicate test key names (see replicateProviderTestKeys). +// ListModels uses the deployments API via the list-models key; all other operations use the inference key. +const ( + ReplicateKeyNameListModels = "replicate-list-models-deployments" + ReplicateKeyNameInference = "replicate-inference" +) + // ProviderOpenAICustom represents the custom OpenAI provider for testing const ProviderOpenAICustom = schemas.ModelProvider("openai-custom") // TestScenarios defines the comprehensive test scenarios type TestScenarios struct { - TextCompletion bool - TextCompletionStream bool - SimpleChat bool - CompletionStream bool - MultiTurnConversation bool - ToolCalls bool - ToolCallsStreaming bool // Streaming tool calls functionality + TextCompletion bool + TextCompletionStream bool + SimpleChat bool + CompletionStream bool + MultiTurnConversation bool + ToolCalls bool + ToolCallsStreaming bool // Streaming tool calls functionality MultipleToolCalls bool MultipleToolCallsStreaming bool // Streaming multiple tool calls (some providers only return 1 tool call in streaming) End2EndToolCalling bool - AutomaticFunctionCall bool - ImageURL bool - ImageBase64 bool - MultipleImages bool - FileBase64 bool - FileURL bool - CompleteEnd2End bool - SpeechSynthesis bool // Text-to-speech functionality - SpeechSynthesisStream bool // Streaming text-to-speech functionality - Transcription bool // Speech-to-text functionality - TranscriptionStream bool // Streaming speech-to-text functionality - Embedding bool // Embedding functionality - Reasoning bool // Reasoning/thinking functionality via Responses API - PromptCaching bool // Prompt caching functionality - ListModels bool // List available models functionality - ImageGeneration bool // Image generation functionality - ImageGenerationStream bool // Streaming image generation functionality - ImageEdit bool // Image edit functionality - ImageEditStream bool // Streaming image edit functionality - 
ImageVariation bool // Image variation functionality - ImageVariationStream bool // Streaming image variation functionality (if supported) - VideoGeneration bool // Video generation functionality - VideoRetrieve bool // Video retrieve functionality - VideoRemix bool // Video remix functionality (OpenAI only) - VideoDownload bool // Video download functionality - VideoList bool // Video list functionality - VideoDelete bool // Video delete functionality - BatchCreate bool // Batch API create functionality - BatchList bool // Batch API list functionality - BatchRetrieve bool // Batch API retrieve functionality - BatchCancel bool // Batch API cancel functionality - BatchResults bool // Batch API results functionality - FileUpload bool // File API upload functionality - FileList bool // File API list functionality - FileRetrieve bool // File API retrieve functionality - FileDelete bool // File API delete functionality - FileContent bool // File API content download functionality - FileBatchInput bool // Whether batch create supports file-based input (InputFileID) - CountTokens bool // Count tokens functionality - ChatAudio bool // Chat completion with audio input/output functionality - StructuredOutputs bool // Structured outputs (JSON schema) functionality - WebSearchTool bool // Web search tool functionality - ContainerCreate bool // Container API create functionality - ContainerList bool // Container API list functionality - ContainerRetrieve bool // Container API retrieve functionality - ContainerDelete bool // Container API delete functionality - ContainerFileCreate bool // Container File API create functionality - ContainerFileList bool // Container File API list functionality - ContainerFileRetrieve bool // Container File API retrieve functionality - ContainerFileContent bool // Container File API content functionality - ContainerFileDelete bool // Container File API delete functionality - PassThroughExtraParams bool // Pass through extra params functionality - 
Rerank bool // Rerank functionality - PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + PassthroughStream) - WebSocketResponses bool // WebSocket Responses API mode - Realtime bool // Realtime API (bidirectional audio/text) - Compaction bool // Server-side compaction (context management) - InterleavedThinking bool // Interleaved thinking between tool calls (beta) - FastMode bool // Fast mode for Opus 4.6 (beta: research preview) + AutomaticFunctionCall bool + ImageURL bool + ImageBase64 bool + MultipleImages bool + FileBase64 bool + FileURL bool + CompleteEnd2End bool + SpeechSynthesis bool // Text-to-speech functionality + SpeechSynthesisStream bool // Streaming text-to-speech functionality + Transcription bool // Speech-to-text functionality + TranscriptionStream bool // Streaming speech-to-text functionality + Embedding bool // Embedding functionality + Reasoning bool // Reasoning/thinking functionality via Responses API + PromptCaching bool // Prompt caching functionality + ListModels bool // List available models functionality + ImageGeneration bool // Image generation functionality + ImageGenerationStream bool // Streaming image generation functionality + ImageEdit bool // Image edit functionality + ImageEditStream bool // Streaming image edit functionality + ImageVariation bool // Image variation functionality + ImageVariationStream bool // Streaming image variation functionality (if supported) + VideoGeneration bool // Video generation functionality + VideoRetrieve bool // Video retrieve functionality + VideoRemix bool // Video remix functionality (OpenAI only) + VideoDownload bool // Video download functionality + VideoList bool // Video list functionality + VideoDelete bool // Video delete functionality + BatchCreate bool // Batch API create functionality + BatchList bool // Batch API list functionality + BatchRetrieve bool // Batch API retrieve functionality + BatchCancel bool // Batch API cancel functionality + BatchResults bool // Batch 
API results functionality + FileUpload bool // File API upload functionality + FileList bool // File API list functionality + FileRetrieve bool // File API retrieve functionality + FileDelete bool // File API delete functionality + FileContent bool // File API content download functionality + FileBatchInput bool // Whether batch create supports file-based input (InputFileID) + CountTokens bool // Count tokens functionality + ChatAudio bool // Chat completion with audio input/output functionality + StructuredOutputs bool // Structured outputs (JSON schema) functionality + WebSearchTool bool // Web search tool functionality + ContainerCreate bool // Container API create functionality + ContainerList bool // Container API list functionality + ContainerRetrieve bool // Container API retrieve functionality + ContainerDelete bool // Container API delete functionality + ContainerFileCreate bool // Container File API create functionality + ContainerFileList bool // Container File API list functionality + ContainerFileRetrieve bool // Container File API retrieve functionality + ContainerFileContent bool // Container File API content functionality + ContainerFileDelete bool // Container File API delete functionality + PassThroughExtraParams bool // Pass through extra params functionality + Rerank bool // Rerank functionality + PassthroughAPI bool // Raw HTTP passthrough API (Passthrough + PassthroughStream) + WebSocketResponses bool // WebSocket Responses API mode + Realtime bool // Realtime API (bidirectional audio/text) + Compaction bool // Server-side compaction (context management) + InterleavedThinking bool // Interleaved thinking between tool calls (beta) + FastMode bool // Fast mode for Opus 4.6 (beta: research preview) + EagerInputStreaming bool // Fine-grained tool input streaming (Anthropic fine-grained-tool-streaming-2025-05-14) + ServerToolsViaOpenAIEndpoint bool // Anthropic server-tool shapes in tools[] via /v1/chat/completions (web_search / web_fetch / 
code_execution)
 }
 
 // ComprehensiveTestConfig extends TestConfig with additional scenarios
@@ -173,6 +182,34 @@ func (account *ComprehensiveTestAccount) GetConfiguredProviders() ([]schemas.Mod
 	}, nil
 }
 
+// replicateProviderTestKeys returns the two Replicate keys used by comprehensive tests:
+func replicateProviderTestKeys() []schemas.Key {
+	return []schemas.Key{
+		{
+			Name:               ReplicateKeyNameListModels,
+			Value:              *schemas.NewEnvVar("env.REPLICATE_API_KEY"),
+			Models:             []string{"*"},
+			Weight:             0,
+			UseForBatchAPI:     bifrost.Ptr(false),
+			ReplicateKeyConfig: &schemas.ReplicateKeyConfig{UseDeploymentsEndpoint: true},
+		},
+		{
+			Name:               ReplicateKeyNameInference,
+			Value:              *schemas.NewEnvVar("env.REPLICATE_API_KEY"),
+			Models:             []string{"*"},
+			Weight:             1.0,
+			UseForBatchAPI:     bifrost.Ptr(true),
+			ReplicateKeyConfig: nil,
+		},
+	}
+}
+
+// ReplicateDirectKeyForListModels returns the key used for Replicate ListModels (deployments endpoint).
+// List-models tests set it on the context as schemas.BifrostContextKeyDirectKey so Bifrost passes only this key.
+func ReplicateDirectKeyForListModels() schemas.Key {
+	return replicateProviderTestKeys()[0]
+}
+
 // GetKeysForProvider returns the API keys and associated models for a given provider.
 func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) {
 	switch providerKey {
@@ -180,7 +217,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.OPENAI_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -189,7 +226,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.OPENAI_API_KEY"), // Use GROQ API key for OpenAI-compatible endpoint
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -198,7 +235,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.ANTHROPIC_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -206,38 +243,38 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 	case schemas.Bedrock:
 		return []schemas.Key{
 			{
-				Models: []string{},
+				Models: []string{"*"},
 				Weight: 1.0,
+				Aliases: map[string]string{
+					"claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+					"claude-4-sonnet":   "global.anthropic.claude-sonnet-4-20250514-v1:0",
+					"claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
+					"claude-4.5-haiku":  "global.anthropic.claude-haiku-4-5-20251001-v1:0",
+				},
 				BedrockKeyConfig: &schemas.BedrockKeyConfig{
 					AccessKey:    *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"),
 					SecretKey:    *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"),
 					SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"),
 					Region:       schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")),
 					ARN:          schemas.NewEnvVar("env.AWS_ARN"),
-					Deployments: map[string]string{
-						"claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
-						"claude-4-sonnet":   "global.anthropic.claude-sonnet-4-20250514-v1:0",
-						"claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
-						"claude-4.5-haiku":  "global.anthropic.claude-haiku-4-5-20251001-v1:0",
-					},
 				},
 			},
 			{
-				Models: []string{},
+				Models: []string{"*"},
 				Weight: 1.0,
+				Aliases: map[string]string{
+					"claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
+					"claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+					"claude-4-sonnet":   "global.anthropic.claude-sonnet-4-20250514-v1:0",
+					"claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
+					"claude-4.5-haiku":  "global.anthropic.claude-haiku-4-5-20251001-v1:0",
+				},
 				BedrockKeyConfig: &schemas.BedrockKeyConfig{
 					AccessKey:    *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"),
 					SecretKey:    *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"),
 					SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"),
 					Region:       schemas.NewEnvVar(getEnvWithDefault("AWS_REGION", "us-east-1")),
 					ARN:          schemas.NewEnvVar("env.AWS_BEDROCK_ARN"),
-					Deployments: map[string]string{
-						"claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
-						"claude-3.7-sonnet": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
-						"claude-4-sonnet":   "global.anthropic.claude-sonnet-4-20250514-v1:0",
-						"claude-4.5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0",
-						"claude-4.5-haiku":  "global.anthropic.claude-haiku-4-5-20251001-v1:0",
-					},
 				},
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -256,7 +293,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.COHERE_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -265,20 +302,20 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:  *schemas.NewEnvVar("env.AZURE_API_KEY"),
-				Models: []string{},
+				Models: []string{"*"},
 				Weight: 1.0,
+				Aliases: schemas.KeyAliases{
+					"gpt-4o":                 "gpt-4o",
+					"gpt-4o-backup":          "gpt-4o-3",
+					"claude-opus-4-5":        "claude-opus-4-5",
+					"o1":                     "o1",
+					"gpt-image-1":            "gpt-image-1",
+					"text-embedding-ada-002": "text-embedding-ada-002",
+					"sora-2":                 "sora-2",
+				},
 				AzureKeyConfig: &schemas.AzureKeyConfig{
-					Endpoint:   *schemas.NewEnvVar("env.AZURE_ENDPOINT"),
-					APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"),
-					Deployments: map[string]string{
-						"gpt-4o":                 "gpt-4o",
-						"gpt-4o-backup":          "gpt-4o-3",
-						"claude-opus-4-5":        "claude-opus-4-5",
-						"o1":                     "o1",
-						"gpt-image-1":            "gpt-image-1",
-						"text-embedding-ada-002": "text-embedding-ada-002",
-						"sora-2":                 "sora-2",
-					},
+					Endpoint:     *schemas.NewEnvVar("env.AZURE_ENDPOINT"),
+					APIVersion:   schemas.NewEnvVar("env.AZURE_API_VERSION"),
 					ClientID:     schemas.NewEnvVar("env.AZURE_CLIENT_ID"),
 					ClientSecret: schemas.NewEnvVar("env.AZURE_CLIENT_SECRET"),
 					TenantID:     schemas.NewEnvVar("env.AZURE_TENANT_ID"),
@@ -287,16 +324,17 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 			},
 			{
 				Value:  *schemas.NewEnvVar("env.AZURE_API_KEY"),
-				Models: []string{},
+				Models: []string{"*"},
 				Weight: 1.0,
+				Aliases: schemas.KeyAliases{
+					"whisper":                   "whisper",
+					"whisper-1":                 "whisper",
+					"gpt-4o-mini-tts":           "gpt-4o-mini-tts",
+					"gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview",
+				},
 				AzureKeyConfig: &schemas.AzureKeyConfig{
 					Endpoint:   *schemas.NewEnvVar("env.AZURE_ENDPOINT"),
 					APIVersion: schemas.NewEnvVar("env.AZURE_API_VERSION"),
-					Deployments: map[string]string{
-						"whisper":                   "whisper",
-						"gpt-4o-mini-tts":           "gpt-4o-mini-tts",
-						"gpt-4o-mini-audio-preview": "gpt-4o-mini-audio-preview",
-					},
 				},
 			},
 		}, nil
@@ -306,7 +344,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:  *schemas.NewEnvVar("env.VERTEX_API_KEY"),
-				Models: []string{"text-multilingual-embedding-002", "google/gemini-2.0-flash-001", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"},
+				Models: []string{"text-multilingual-embedding-002", "gemini-2.5-pro", "gemini-2.5-flash-image", "imagen-4.0-generate-001", "imagen-3.0-capability-001", "semantic-ranker-default@latest", "semantic-ranker-default-004"},
 				Weight: 1.0,
 				VertexKeyConfig: &schemas.VertexKeyConfig{
 					ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"),
@@ -330,15 +368,15 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 				Value:  *schemas.NewEnvVar("env.VERTEX_API_KEY"),
 				Models: []string{"claude-sonnet-4-5", "claude-4.5-haiku", "claude-opus-4-5"},
 				Weight: 1.0,
+				Aliases: schemas.KeyAliases{
+					"claude-sonnet-4-5": "claude-sonnet-4-5",
+					"claude-4.5-haiku":  "claude-haiku-4-5@20251001",
+					"claude-opus-4-5":   "claude-opus-4-5",
+				},
 				VertexKeyConfig: &schemas.VertexKeyConfig{
 					ProjectID:       *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"),
 					Region:          *schemas.NewEnvVar(getEnvWithDefault("VERTEX_REGION_ANTHROPIC", "us-east5")),
 					AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"),
-					Deployments: map[string]string{
-						"claude-sonnet-4-5": "claude-sonnet-4-5",
-						"claude-4.5-haiku":  "claude-haiku-4-5@20251001",
-						"claude-opus-4-5":   "claude-opus-4-5",
-					},
 				},
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -347,7 +385,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.MISTRAL_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -356,7 +394,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.GROQ_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -365,7 +403,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.PARASAIL_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -374,7 +412,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.ELEVENLABS_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -383,7 +421,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.PERPLEXITY_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -392,7 +430,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.CEREBRAS_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -401,7 +439,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.GEMINI_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -410,7 +448,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.OPENROUTER_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -419,7 +457,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.HUGGING_FACE_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -428,7 +466,7 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.NEBIUS_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -437,25 +475,18 @@ func (account *ComprehensiveTestAccount) GetKeysForProvider(ctx context.Context,
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.XAI_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
 		}, nil
 	case schemas.Replicate:
-		return []schemas.Key{
-			{
-				Value:          *schemas.NewEnvVar("env.REPLICATE_API_KEY"),
-				Models:         []string{},
-				Weight:         1.0,
-				UseForBatchAPI: bifrost.Ptr(true),
-			},
-		}, nil
+		return replicateProviderTestKeys(), nil
 	case schemas.Runway:
 		return []schemas.Key{
 			{
 				Value:          *schemas.NewEnvVar("env.RUNWAY_API_KEY"),
-				Models:         []string{},
+				Models:         []string{"*"},
 				Weight:         1.0,
 				UseForBatchAPI: bifrost.Ptr(true),
 			},
@@ -826,53 +857,54 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ImageVariationModel: "dall-e-2",
 		ChatAudioModel:      "gpt-4o-mini-audio-preview",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported
-			TextCompletionStream:  false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported
+			TextCompletionStream:       false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       true,  // OpenAI supports TTS
-			SpeechSynthesisStream: true,  // OpenAI supports streaming TTS
-			Transcription:         true,  // OpenAI supports STT with Whisper
-			TranscriptionStream:   true,  // OpenAI supports streaming STT
-			ImageGeneration:       true,  // OpenAI supports image generation with DALL-E
-			ImageGenerationStream: true,  // OpenAI supports streaming image generation
-			ImageEdit:             true,  // OpenAI supports image editing
-			ImageEditStream:       true,  // OpenAI supports streaming image editing
-			ImageVariation:        true,  // OpenAI supports image variation
-			ImageVariationStream:  false, // OpenAI does not support streaming image variation
-			Embedding:             true,
-			Reasoning:             true, // OpenAI supports reasoning via o1 models
-			ListModels:            true,
-			BatchCreate:           true, // OpenAI supports batch API
-			BatchList:             true, // OpenAI supports batch API
-			BatchRetrieve:         true, // OpenAI supports batch API
-			BatchCancel:           true, // OpenAI supports batch API
-			BatchResults:          true, // OpenAI supports batch API
-			FileUpload:            true, // OpenAI supports file API
-			FileList:              true, // OpenAI supports file API
-			FileRetrieve:          true, // OpenAI supports file API
-			FileDelete:            true, // OpenAI supports file API
-			FileContent:           true, // OpenAI supports file API
-			ChatAudio:             true, // OpenAI supports chat audio
-			ContainerCreate:       true, // OpenAI supports container API
-			ContainerList:         true, // OpenAI supports container API
-			ContainerRetrieve:     true, // OpenAI supports container API
-			ContainerDelete:       true, // OpenAI supports container API
-			ContainerFileCreate:   true, // OpenAI supports container file API
-			ContainerFileList:     true, // OpenAI supports container file API
-			ContainerFileRetrieve: true, // OpenAI supports container file API
-			ContainerFileContent:  true, // OpenAI supports container file API
-			ContainerFileDelete:   true, // OpenAI supports container file API
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            true,  // OpenAI supports TTS
+			SpeechSynthesisStream:      true,  // OpenAI supports streaming TTS
+			Transcription:              true,  // OpenAI supports STT with Whisper
+			TranscriptionStream:        true,  // OpenAI supports streaming STT
+			ImageGeneration:            true,  // OpenAI supports image generation with DALL-E
+			ImageGenerationStream:      true,  // OpenAI supports streaming image generation
+			ImageEdit:                  true,  // OpenAI supports image editing
+			ImageEditStream:            true,  // OpenAI supports streaming image editing
+			ImageVariation:             true,  // OpenAI supports image variation
+			ImageVariationStream:       false, // OpenAI does not support streaming image variation
+			Embedding:                  true,
+			Reasoning:                  true, // OpenAI supports reasoning via o1 models
+			ListModels:                 true,
+			BatchCreate:                true, // OpenAI supports batch API
+			BatchList:                  true, // OpenAI supports batch API
+			BatchRetrieve:              true, // OpenAI supports batch API
+			BatchCancel:                true, // OpenAI supports batch API
+			BatchResults:               true, // OpenAI supports batch API
+			FileUpload:                 true, // OpenAI supports file API
+			FileList:                   true, // OpenAI supports file API
+			FileRetrieve:               true, // OpenAI supports file API
+			FileDelete:                 true, // OpenAI supports file API
+			FileContent:                true, // OpenAI supports file API
+			ChatAudio:                  true, // OpenAI supports chat audio
+			ContainerCreate:            true, // OpenAI supports container API
+			ContainerList:              true, // OpenAI supports container API
+			ContainerRetrieve:          true, // OpenAI supports container API
+			ContainerDelete:            true, // OpenAI supports container API
+			ContainerFileCreate:        true, // OpenAI supports container file API
+			ContainerFileList:          true, // OpenAI supports container file API
+			ContainerFileRetrieve:      true, // OpenAI supports container file API
+			ContainerFileContent:       true, // OpenAI supports container file API
+			ContainerFileDelete:        true, // OpenAI supports container file API
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"},
@@ -883,37 +915,38 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ChatModel: "claude-3-7-sonnet-20250219",
 		TextModel: "", // Anthropic doesn't support text completion
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			PromptCaching:         true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             false,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             false, // Anthropic does not support image editing
-			ImageEditStream:       false, // Anthropic does not support streaming image editing
-			ImageVariation:        false, // Anthropic does not support image variation
-			ImageVariationStream:  false, // Anthropic does not support streaming image variation
-			ListModels:            true,
-			BatchCreate:           true, // Anthropic supports batch API
-			BatchList:             true, // Anthropic supports batch API
-			BatchRetrieve:         true, // Anthropic supports batch API
-			BatchCancel:           true, // Anthropic supports batch API
-			BatchResults:          true, // Anthropic supports batch API
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			PromptCaching:              true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  false,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  false, // Anthropic does not support image editing
+			ImageEditStream:            false, // Anthropic does not support streaming image editing
+			ImageVariation:             false, // Anthropic does not support image variation
+			ImageVariationStream:       false, // Anthropic does not support streaming image variation
+			ListModels:                 true,
+			BatchCreate:                true, // Anthropic supports batch API
+			BatchList:                  true, // Anthropic supports batch API
+			BatchRetrieve:              true, // Anthropic supports batch API
+			BatchCancel:                true, // Anthropic supports batch API
+			BatchResults:               true, // Anthropic supports batch API
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -926,42 +959,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ImageEditModel:      "amazon.titan-image-generator-v1",
 		ImageVariationModel: "amazon.titan-image-generator-v1",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported for Claude
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported for Claude
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			PromptCaching:         true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             true,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             true,  // Bedrock supports image editing
-			ImageEditStream:       false, // Bedrock does not support streaming image editing
-			ImageVariation:        true,  // Bedrock supports image variation
-			ImageVariationStream:  false, // Bedrock does not support streaming image variation
-			ListModels:            true,
-			BatchCreate:           true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config)
-			BatchList:             true, // Bedrock supports listing batch jobs
-			BatchRetrieve:         true, // Bedrock supports retrieving batch jobs
-			BatchCancel:           true, // Bedrock supports stopping batch jobs
-			BatchResults:          true, // Bedrock batch results via S3
-			FileUpload:            true, // Bedrock file upload to S3 (requires S3 config)
-			FileList:              true, // Bedrock file list from S3 (requires S3 config)
-			FileRetrieve:          true, // Bedrock file retrieve from S3 (requires S3 config)
-			FileDelete:            true, // Bedrock file delete from S3 (requires S3 config)
-			FileContent:           true, // Bedrock file content from S3 (requires S3 config)
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			PromptCaching:              true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  true,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  true,  // Bedrock supports image editing
+			ImageEditStream:            false, // Bedrock does not support streaming image editing
+			ImageVariation:             true,  // Bedrock supports image variation
+			ImageVariationStream:       false, // Bedrock does not support streaming image variation
+			ListModels:                 true,
+			BatchCreate:                true, // Bedrock supports batch via Model Invocation Jobs (requires S3 config)
+			BatchList:                  true, // Bedrock supports listing batch jobs
+			BatchRetrieve:              true, // Bedrock supports retrieving batch jobs
+			BatchCancel:                true, // Bedrock supports stopping batch jobs
+			BatchResults:               true, // Bedrock batch results via S3
+			FileUpload:                 true, // Bedrock file upload to S3 (requires S3 config)
+			FileList:                   true, // Bedrock file list from S3 (requires S3 config)
+			FileRetrieve:               true, // Bedrock file retrieve from S3 (requires S3 config)
+			FileDelete:                 true, // Bedrock file delete from S3 (requires S3 config)
+			FileContent:                true, // Bedrock file content from S3 (requires S3 config)
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -972,31 +1006,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ChatModel: "command-a-03-2025",
 		TextModel: "", // Cohere focuses on chat
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not typical for Cohere
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not typical for Cohere
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: false, // May not support automatic
-			ImageURL:              false, // Check if supported
-			ImageBase64:           false, // Check if supported
-			MultipleImages:        false, // Check if supported
-			CompleteEnd2End:       true,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             false, // Cohere does not support image editing
-			ImageEditStream:       false, // Cohere does not support streaming image editing
-			ImageVariation:        false, // Cohere does not support image variation
-			ImageVariationStream:  false, // Cohere does not support streaming image variation
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             true,
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      false, // May not support automatic
+			ImageURL:                   false, // Check if supported
+			ImageBase64:                false, // Check if supported
+			MultipleImages:             false, // Check if supported
+			CompleteEnd2End:            true,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  false, // Cohere does not support image editing
+			ImageEditStream:            false, // Cohere does not support streaming image editing
+			ImageVariation:             false, // Cohere does not support image variation
+			ImageVariationStream:       false, // Cohere does not support streaming image variation
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  true,
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1012,42 +1047,43 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ImageGenerationModel: "gpt-image-1",
 		ImageEditModel:       "dall-e-2",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       true,  // Supported via gpt-4o-mini-tts
-			SpeechSynthesisStream: true,  // Supported via gpt-4o-mini-tts
-			Transcription:         true,  // Supported via whisper-1
-			TranscriptionStream:   false, // Not properly supported yet by Azure
-			Embedding:             true,
-			ImageGeneration:       false, // Skipped for Azure
-			ImageGenerationStream: false, // Skipped for Azure
-			ImageEdit:             true,  // Azure supports image editing
-			ImageEditStream:       true,  // Azure supports streaming image editing
-			ImageVariation:        false, // Azure does not support image variation
-			ImageVariationStream:  false, // Azure does not support streaming image variation
-			ListModels:            true,
-			BatchCreate:           true, // Azure supports batch API
-			BatchList:             true, // Azure supports batch API
-			BatchRetrieve:         true, // Azure supports batch API
-			BatchCancel:           true, // Azure supports batch API
-			BatchResults:          true, // Azure supports batch API
-			FileUpload:            true, // Azure supports file API
-			FileList:              true, // Azure supports file API
-			FileRetrieve:          true, // Azure supports file API
-			FileDelete:            true, // Azure supports file API
-			FileContent:           true, // Azure supports file API
-			ChatAudio:             true, // Azure supports chat audio
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            true,  // Supported via gpt-4o-mini-tts
+			SpeechSynthesisStream:      true,  // Supported via gpt-4o-mini-tts
+			Transcription:              true,  // Supported via whisper-1
+			TranscriptionStream:        false, // Not properly supported yet by Azure
+			Embedding:                  true,
+			ImageGeneration:            false, // Skipped for Azure
+			ImageGenerationStream:      false, // Skipped for Azure
+			ImageEdit:                  true,  // Azure supports image editing
+			ImageEditStream:            true,  // Azure supports streaming image editing
+			ImageVariation:             false, // Azure does not support image variation
+			ImageVariationStream:       false, // Azure does not support streaming image variation
+			ListModels:                 true,
+			BatchCreate:                true, // Azure supports batch API
+			BatchList:                  true, // Azure supports batch API
+			BatchRetrieve:              true, // Azure supports batch API
+			BatchCancel:                true, // Azure supports batch API
+			BatchResults:               true, // Azure supports batch API
+			FileUpload:                 true, // Azure supports file API
+			FileList:                   true, // Azure supports file API
+			FileRetrieve:               true, // Azure supports file API
+			FileDelete:                 true, // Azure supports file API
+			FileContent:                true, // Azure supports file API
+			ChatAudio:                  true, // Azure supports chat audio
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1061,31 +1097,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ImageGenerationModel: "imagen-4.0-generate-001",
 		ImageEditModel:       "imagen-4.0-generate-001",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not typical
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not typical
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			ImageGeneration:       true,
-			ImageGenerationStream: false,
-			ImageEdit:             true,  // Vertex supports image editing
-			ImageEditStream:       false, // Vertex does not support streaming image editing
-			ImageVariation:        false, // Vertex does not support image variation
-			ImageVariationStream:  false, // Vertex does not support streaming image variation
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             true,
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			ImageGeneration:            true,
+			ImageGenerationStream:      false,
+			ImageEdit:                  true,  // Vertex supports image editing
+			ImageEditStream:            false, // Vertex does not support streaming image editing
+			ImageVariation:             false, // Vertex does not support image variation
+			ImageVariationStream:       false, // Vertex does not support streaming image variation
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  true,
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1097,30 +1134,31 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		TextModel:          "", // Mistral focuses on chat
 		TranscriptionModel: "voxtral-mini-latest",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not typical
-			SimpleChat:            true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not typical
+			SimpleChat:                 true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         true,  // Supported via voxtral-mini-latest
-			TranscriptionStream:   true,  // Supported via voxtral-mini-latest
-			Embedding:             true,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             false, // Mistral does not support image editing
-			ImageEditStream:       false, // Mistral does not support streaming image editing
-			ImageVariation:        false, // Mistral does not support image variation
-			ImageVariationStream:  false, // Mistral does not support streaming image variation
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              true,  // Supported via voxtral-mini-latest
+			TranscriptionStream:        true,  // Supported via voxtral-mini-latest
+			Embedding:                  true,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  false, // Mistral does not support image editing
+			ImageEditStream:            false, // Mistral does not support streaming image editing
+			ImageVariation:             false, // Mistral does not support image variation
+			ImageVariationStream:       false, // Mistral does not support streaming image variation
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1131,31 +1169,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ChatModel: "llama3.2",
 		TextModel: "", // Ollama focuses on chat
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not typical
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not typical
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             false,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             false, // Ollama does not support image editing
-			ImageEditStream:       false, // Ollama does not support streaming image editing
-			ImageVariation:        false, // Ollama does not support image variation
-			ImageVariationStream:  false, // Ollama does not support streaming image variation
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  false,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  false, // Ollama does not support image editing
+			ImageEditStream:            false, // Ollama does not support streaming image editing
+			ImageVariation:             false, // Ollama does not support image variation
+			ImageVariationStream:       false, // Ollama does not support streaming image variation
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1166,31 +1205,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ChatModel: "llama-3.3-70b-versatile",
 		TextModel: "", // Groq doesn't support text completion
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             false,
-			ImageGeneration:       false,
-			ImageGenerationStream: false,
-			ImageEdit:             false, // Groq does not support image editing
-			ImageEditStream:       false, // Groq does not support streaming image editing
-			ImageVariation:        false, // Groq does not support image variation
-			ImageVariationStream:  false, // Groq does not support streaming image variation
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  false,
+			ImageGeneration:            false,
+			ImageGenerationStream:      false,
+			ImageEdit:                  false, // Groq does not support image editing
+			ImageEditStream:            false, // Groq does not support streaming image editing
+			ImageVariation:             false, // Groq does not support image variation
+			ImageVariationStream:       false, // Groq does not support streaming image variation
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1231,31 +1271,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ChatModel: "llama-3.3-70b-versatile",
 		TextModel: "", // Custom OpenAI instance doesn't support text completion
 		Scenarios: TestScenarios{
-			TextCompletion:        false,
-			SimpleChat:            true, // Enable simple chat for testing
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false,
+			SimpleChat:                 true, // Enable simple chat for testing
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              false,
-			ImageBase64:           false,
-			MultipleImages:        false,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       false, // Not supported
-			SpeechSynthesisStream: false, // Not supported
-			Transcription:         false, // Not supported
-			TranscriptionStream:   false, // Not supported
-			Embedding:             false,
-			ImageGeneration:       false, // ProviderOpenAICustom does not support image generation
-			ImageGenerationStream: false, // ProviderOpenAICustom does not support streaming image generation
-			ImageEdit:             false, // ProviderOpenAICustom does not support image editing
-			ImageEditStream:       false, // ProviderOpenAICustom does not support streaming image editing
-			ImageVariation:        false, // ProviderOpenAICustom does not support image variation
-			ImageVariationStream:  false, // ProviderOpenAICustom does not support streaming image variation
-			ListModels:            true,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   false,
+			ImageBase64:                false,
+			MultipleImages:             false,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            false, // Not supported
+			SpeechSynthesisStream:      false, // Not supported
+			Transcription:              false, // Not supported
+			TranscriptionStream:        false, // Not supported
+			Embedding:                  false,
+			ImageGeneration:            false, // ProviderOpenAICustom does not support image generation
+			ImageGenerationStream:      false, // ProviderOpenAICustom does not support streaming image generation
+			ImageEdit:                  false, // ProviderOpenAICustom does not support image editing
+			ImageEditStream:            false, // ProviderOpenAICustom does not support streaming image editing
+			ImageVariation:             false, // ProviderOpenAICustom does not support image variation
+			ImageVariationStream:       false, // ProviderOpenAICustom does not support streaming image variation
+			ListModels:                 true,
 		},
 		Fallbacks: []schemas.Fallback{
 			{Provider: schemas.OpenAI, Model: "gpt-4o-mini"},
@@ -1271,41 +1312,42 @@ var AllProviderConfigs = []ComprehensiveTestConfig{
 		ImageGenerationModel: "imagen-4.0-generate-001",
 		ImageEditModel:       "imagen-4.0-generate-001",
 		Scenarios: TestScenarios{
-			TextCompletion:        false, // Not supported
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
+			TextCompletion:             false, // Not supported
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       true,
-			SpeechSynthesisStream: true,
-			Transcription:         true,
-			TranscriptionStream:   true,
-			Embedding:             true,
-			ImageGeneration:       true,
-			ImageGenerationStream: false,
-			ImageEdit:             true,  // Gemini supports image editing
-			ImageEditStream:       false, // Gemini does not support streaming image editing
-			ImageVariation:        false, // Gemini does not support image variation
-			ImageVariationStream:  false, // Gemini does not support streaming image variation
-			ListModels:            true,
-			BatchCreate:           true,
-			BatchList:             true,
-			BatchRetrieve:         true,
-			BatchCancel:           true,
-			BatchResults:          true,
-			FileUpload:            true,
-			FileList:              true,
-			FileRetrieve:          true,
-			FileDelete:            true,
-			FileContent:           false, // Gemini doesn't support direct content download
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            true,
+			SpeechSynthesisStream:      true,
+			Transcription:              true,
+			TranscriptionStream:        true,
+			Embedding:                  true,
+			ImageGeneration:            true,
+			ImageGenerationStream:      false,
+			ImageEdit:                  true,  // Gemini supports image editing
+			ImageEditStream:            false, // Gemini does not support streaming image editing
+			ImageVariation:             false, // Gemini does not support
image variation + ImageVariationStream: false, // Gemini does not support streaming image variation + ListModels: true, + BatchCreate: true, + BatchList: true, + BatchRetrieve: true, + BatchCancel: true, + BatchResults: true, + FileUpload: true, + FileList: true, + FileRetrieve: true, + FileDelete: true, + FileContent: false, // Gemini doesn't support direct content download }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1316,31 +1358,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ ChatModel: "openai/gpt-4o", TextModel: "google/gemini-2.5-flash", Scenarios: TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // OpenRouter does not support image editing - ImageEditStream: false, // OpenRouter does not support streaming image editing - ImageVariation: false, // OpenRouter does not support image variation - ImageVariationStream: false, // OpenRouter does not support streaming image variation - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: false, - TranscriptionStream: false, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // OpenRouter does not support image editing + ImageEditStream: false, // OpenRouter does not support streaming image editing + ImageVariation: false, // 
OpenRouter does not support image variation + ImageVariationStream: false, // OpenRouter does not support streaming image variation + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: false, + TranscriptionStream: false, + Embedding: false, + ListModels: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, @@ -1394,31 +1437,32 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "", // XAI focuses on chat ImageGenerationModel: "grok-2-image", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ImageGeneration: true, - ImageGenerationStream: false, - ImageEdit: false, // XAI does not support image editing - ImageEditStream: false, // XAI does not support streaming image editing - ImageVariation: false, // XAI does not support image variation - ImageVariationStream: false, // XAI does not support streaming image variation - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + 
ImageGeneration: true, + ImageGenerationStream: false, + ImageEdit: false, // XAI does not support image editing + ImageEditStream: false, // XAI does not support streaming image editing + ImageVariation: false, // XAI does not support image variation + ImageVariationStream: false, // XAI does not support streaming image variation + ListModels: true, }, }, { @@ -1427,27 +1471,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ TextModel: "openai/gpt-4.1-mini", ImageGenerationModel: "black-forest-labs/flux-dev", Scenarios: TestScenarios{ - TextCompletion: false, // Not typical - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + TextCompletion: false, // Not typical + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: false, // Not supported - TranscriptionStream: false, // Not supported - Embedding: false, // Not supported - ListModels: true, - ImageGeneration: true, - ImageGenerationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + CompleteEnd2End: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: false, // Not supported + TranscriptionStream: false, // Not supported + Embedding: false, // Not supported + ListModels: true, + ImageGeneration: true, + ImageGenerationStream: false, }, }, { Provider: schemas.VLLM, @@ -1456,27 +1501,28 @@ var AllProviderConfigs = []ComprehensiveTestConfig{ EmbeddingModel: "Qwen/Qwen3-Embedding-0.6B", TranscriptionModel: "openai/whisper-small", Scenarios: 
TestScenarios{ - SpeechSynthesis: false, // Not supported - SpeechSynthesisStream: false, // Not supported - Transcription: true, // VLLM supports transcription - TranscriptionStream: true, // VLLM supports transcription streaming - Embedding: true, // VLLM supports embedding - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, // VLLM does not support image editing - ImageEditStream: false, // VLLM does not support streaming image editing - ImageVariation: false, // VLLM does not support image variation - ImageVariationStream: false, // VLLM does not support streaming image variation - ListModels: true, - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, + SpeechSynthesis: false, // Not supported + SpeechSynthesisStream: false, // Not supported + Transcription: true, // VLLM supports transcription + TranscriptionStream: true, // VLLM supports transcription streaming + Embedding: true, // VLLM supports embedding + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, // VLLM does not support image editing + ImageEditStream: false, // VLLM does not support streaming image editing + ImageVariation: false, // VLLM does not support image variation + ImageVariationStream: false, // VLLM does not support streaming image variation + ListModels: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, + End2EndToolCalling: true, }, Fallbacks: []schemas.Fallback{ {Provider: schemas.OpenAI, Model: "gpt-4o-mini"}, diff --git a/core/internal/llmtests/complete_end_to_end.go b/core/internal/llmtests/complete_end_to_end.go index 607ca1ca86..399c79845a 100644 --- a/core/internal/llmtests/complete_end_to_end.go +++ 
b/core/internal/llmtests/complete_end_to_end.go @@ -70,7 +70,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C ToolChoice: &schemas.ChatToolChoice{ ChatToolChoiceStr: bifrost.Ptr(string(schemas.ChatToolChoiceTypeRequired)), }, - MaxCompletionTokens: bifrost.Ptr(150), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -88,7 +88,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C ToolChoice: &schemas.ResponsesToolChoice{ ResponsesToolChoiceStr: bifrost.Ptr(string(schemas.ResponsesToolChoiceTypeRequired)), }, - MaxOutputTokens: bifrost.Ptr(150), + MaxOutputTokens: bifrost.Ptr(400), }, } return client.ResponsesRequest(bfCtx, responsesReq) @@ -205,7 +205,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: testConfig.ChatModel, Input: chatConversationHistory, Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(200), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -219,7 +219,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: testConfig.ChatModel, Input: responsesConversationHistory, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(200), + MaxOutputTokens: bifrost.Ptr(400), }, } return client.ResponsesRequest(bfCtx, responsesReq) @@ -343,7 +343,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: model, Input: chatConversationHistory, Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(200), + MaxCompletionTokens: bifrost.Ptr(400), }, Fallbacks: testConfig.Fallbacks, } @@ -357,7 +357,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Model: model, Input: responsesConversationHistory, Params: &schemas.ResponsesParameters{ - MaxOutputTokens: bifrost.Ptr(200), + MaxOutputTokens: bifrost.Ptr(400), }, } return 
client.ResponsesRequest(bfCtx, responsesReq) diff --git a/core/internal/llmtests/eager_input_streaming.go b/core/internal/llmtests/eager_input_streaming.go new file mode 100644 index 0000000000..0f074c46af --- /dev/null +++ b/core/internal/llmtests/eager_input_streaming.go @@ -0,0 +1,134 @@ +package llmtests + +import ( + "context" + "os" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunEagerInputStreamingTest tests that setting eager_input_streaming: true on +// a custom tool succeeds end-to-end against the target Anthropic-family +// provider. Per Table 20 (verified against A overview + B-header), the +// fine-grained-tool-streaming-2025-05-14 beta is supported on Anthropic, +// Bedrock, Vertex, and Azure. +// +// The test verifies: +// 1. The request is accepted (no upstream 400 — which would indicate the +// fine-grained-tool-streaming-2025-05-14 beta header wasn't injected or +// is rejected by the target provider). +// 2. The stream produces a tool call with a valid JSON arguments payload. +// 3. The response is otherwise well-formed. +// +// This intentionally runs across all four providers (no single-provider gate +// unlike RunFastModeTest, which is Opus-4.6-only). +func RunEagerInputStreamingTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.EagerInputStreaming { + t.Logf("EagerInputStreaming not supported for provider %s", testConfig.Provider) + return + } + + t.Run("EagerInputStreaming", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + chatTool := GetSampleChatTool(SampleToolTypeWeather) + // Opt the tool into fine-grained input streaming. The neutral flag + // on ChatTool is promoted through ToAnthropicChatRequest, which also + // triggers the fine-grained-tool-streaming-2025-05-14 beta header. 
+ eager := true + chatTool.EagerInputStreaming = &eager + + chatMessages := []schemas.ChatMessage{ + CreateBasicChatMessage("What's the weather like in San Francisco? answer in celsius"), + } + + request := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: chatMessages, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(200), + Tools: []schemas.ChatTool{*chatTool}, + }, + Fallbacks: testConfig.Fallbacks, + } + + retryConfig := StreamingRetryConfig() + retryContext := TestRetryContext{ + ScenarioName: "EagerInputStreaming", + ExpectedBehavior: map[string]interface{}{ + "should_stream_content": true, + "should_have_tool_calls": true, + "tool_name": "get_weather", + }, + TestMetadata: map[string]interface{}{ + "provider": testConfig.Provider, + "model": testConfig.ChatModel, + "eager_input_streaming": true, + }, + } + + responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + return client.ChatCompletionStreamRequest(bfCtx, request) + }) + + RequireNoError(t, err, "Eager input streaming request failed") + if responseChannel == nil { + t.Fatal("Response channel should not be nil") + } + + accumulator := NewStreamingToolCallAccumulator() + var responseCount int + var sawAny bool + + t.Logf("🔧 Testing eager input streaming (fine-grained-tool-streaming-2025-05-14)...") + + for response := range responseChannel { + if response == nil || response.BifrostChatResponse == nil { + continue + } + responseCount++ + sawAny = true + + if response.BifrostChatResponse.Choices != nil { + for i, choice := range response.BifrostChatResponse.Choices { + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + delta := choice.ChatStreamResponseChoice.Delta + for _, tc := range delta.ToolCalls { + accumulator.AccumulateChatToolCall(i, tc) + } 
+ } + } + } + } + + if !sawAny { + t.Fatal("Expected at least one streaming response chunk") + } + t.Logf("Received %d chunks", responseCount) + + // Validate the accumulated tool call is well-formed. If the + // fine-grained-tool-streaming beta header weren't sent (or the + // provider rejected it), the upstream would have returned a 400 + // before any tool_use blocks were emitted. + toolCalls := accumulator.GetFinalChatToolCalls() + if len(toolCalls) == 0 { + t.Error("Expected at least one tool call in stream") + } + for _, tc := range toolCalls { + if tc.Name == "" { + t.Error("Tool call missing function name") + } + if tc.Arguments == "" { + t.Error("Tool call missing arguments JSON") + } + } + + t.Logf("EagerInputStreaming passed: %d tool calls accumulated", len(toolCalls)) + }) +} diff --git a/core/internal/llmtests/image_edit.go b/core/internal/llmtests/image_edit.go index 56ad66d502..deed0bd820 100644 --- a/core/internal/llmtests/image_edit.go +++ b/core/internal/llmtests/image_edit.go @@ -364,8 +364,8 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context t.Error("❌ ExtraFields.Provider is empty") } - if imageEditResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageEditResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageEditRequest @@ -374,7 +374,7 @@ func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.ModelRequested, len(imageEditResponse.Data)) + imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data)) }) } diff --git a/core/internal/llmtests/image_generation.go 
b/core/internal/llmtests/image_generation.go index 81a0626978..1516ff0088 100644 --- a/core/internal/llmtests/image_generation.go +++ b/core/internal/llmtests/image_generation.go @@ -145,12 +145,12 @@ func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Error("❌ ExtraFields.Provider is empty") } - if imageGenerationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.ModelRequested, len(imageGenerationResponse.Data)) + imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data)) }) } diff --git a/core/internal/llmtests/image_variation.go b/core/internal/llmtests/image_variation.go index 0aca33a63f..d0c4d18e78 100644 --- a/core/internal/llmtests/image_variation.go +++ b/core/internal/llmtests/image_variation.go @@ -162,8 +162,8 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co t.Error("❌ ExtraFields.Provider is empty") } - if imageVariationResponse.ExtraFields.ModelRequested == "" { - t.Error("❌ ExtraFields.ModelRequested is empty") + if imageVariationResponse.ExtraFields.OriginalModelRequested == "" { + t.Error("❌ ExtraFields.OriginalModelRequested is empty") } // Validate RequestType is ImageVariationRequest @@ -172,7 +172,7 @@ func RunImageVariationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Logf("✅ Image variation successful: ID=%s, Provider=%s, Model=%s, Images=%d", - imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.ModelRequested, 
len(imageVariationResponse.Data)) + imageVariationResponse.ID, imageVariationResponse.ExtraFields.Provider, imageVariationResponse.ExtraFields.OriginalModelRequested, len(imageVariationResponse.Data)) }) } diff --git a/core/internal/llmtests/list_models.go b/core/internal/llmtests/list_models.go index f0b5bcf7c5..3c2133aef5 100644 --- a/core/internal/llmtests/list_models.go +++ b/core/internal/llmtests/list_models.go @@ -9,6 +9,17 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// listModelsBifrostContext returns a context for ListModels. For Replicate, sets BifrostContextKeyDirectKey +// so only the deployments key is used (see replicateProviderTestKeys in account.go). That key must not use an +// empty Models allowlist, or ListModelsPipeline.ShouldEarlyExit returns no models before the API runs. +func listModelsBifrostContext(parent context.Context, provider schemas.ModelProvider) *schemas.BifrostContext { + bfCtx := schemas.NewBifrostContext(parent, schemas.NoDeadline) + if provider == schemas.Replicate { + bfCtx.SetValue(schemas.BifrostContextKeyDirectKey, ReplicateDirectKeyForListModels()) + } + return bfCtx +} + // RunListModelsTest executes the list models test scenario func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { if !testConfig.Scenarios.ListModels { @@ -59,7 +70,7 @@ func RunListModelsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModels", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, request) }) @@ -154,7 +165,7 @@ func RunListModelsResponseMarshalTest(t *testing.T, client *bifrost.Bifrost, ctx } response, bifrostErr := WithListModelsTestRetry(t, 
listModelsRetryConfig, retryContext, expectations, "ListModelsResponseMarshal", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, request) }) @@ -293,7 +304,7 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont } response, bifrostErr := WithListModelsTestRetry(t, listModelsRetryConfig, retryContext, expectations, "ListModelsPagination", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, request) }) @@ -336,7 +347,7 @@ func RunListModelsPaginationTest(t *testing.T, client *bifrost.Bifrost, ctx cont } nextPageResponse, nextPageErr := WithListModelsTestRetry(t, listModelsRetryConfig, nextPageRetryContext, expectations, "ListModelsPagination_NextPage", func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + bfCtx := listModelsBifrostContext(ctx, testConfig.Provider) return client.ListModelsRequest(bfCtx, nextPageRequest) }) diff --git a/core/internal/llmtests/provider_feature_support_test.go b/core/internal/llmtests/provider_feature_support_test.go index 6e4738c282..539f4049c0 100644 --- a/core/internal/llmtests/provider_feature_support_test.go +++ b/core/internal/llmtests/provider_feature_support_test.go @@ -654,6 +654,77 @@ func TestProviderBetaHeaderInjection(t *testing.T) { }, expectHeaders: []string{"computer-use-2025-01-24"}, }, + + // ── Fine-grained tool streaming header (eager_input_streaming) ── + // Per cited citations (A overview table + B-header): EagerInputStreaming + // is supported on Anthropic, Bedrock, Vertex, and Azure — all four + // should auto-inject 
fine-grained-tool-streaming-2025-05-14 when a + // tool has eager_input_streaming: true. + { + name: "Anthropic/eager_input_streaming_header_added", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + eager := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1", EagerInputStreaming: &eager}}, + } + }, + expectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, + { + name: "Bedrock/eager_input_streaming_header_added", + provider: schemas.Bedrock, + setupReq: func() *anthropic.AnthropicMessageRequest { + eager := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1", EagerInputStreaming: &eager}}, + } + }, + expectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, + { + name: "Vertex/eager_input_streaming_header_added", + provider: schemas.Vertex, + setupReq: func() *anthropic.AnthropicMessageRequest { + eager := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1", EagerInputStreaming: &eager}}, + } + }, + expectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, + { + name: "Azure/eager_input_streaming_header_added", + provider: schemas.Azure, + setupReq: func() *anthropic.AnthropicMessageRequest { + eager := true + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1", EagerInputStreaming: &eager}}, + } + }, + expectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, + { + name: "eager_input_streaming_header_skipped_when_flag_false", + provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + eager := false + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1", EagerInputStreaming: &eager}}, + } + }, + unexpectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, + { + name: "eager_input_streaming_header_skipped_when_unset", + 
provider: schemas.Anthropic, + setupReq: func() *anthropic.AnthropicMessageRequest { + return &anthropic.AnthropicMessageRequest{ + Tools: []anthropic.AnthropicTool{{Name: "t1"}}, + } + }, + unexpectHeaders: []string{"fine-grained-tool-streaming-2025-05-14"}, + }, } for _, tt := range tests { diff --git a/core/internal/llmtests/realtime.go b/core/internal/llmtests/realtime.go index 821aeba9eb..400f5f9cda 100644 --- a/core/internal/llmtests/realtime.go +++ b/core/internal/llmtests/realtime.go @@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/llmtests/response_validation.go b/core/internal/llmtests/response_validation.go index bcc419abd3..bc75dd07df 100644 --- a/core/internal/llmtests/response_validation.go +++ b/core/internal/llmtests/response_validation.go @@ -859,7 +859,7 @@ func validateResponsesBasicStructure(response *schemas.BifrostResponsesResponse, } provider := response.ExtraFields.Provider - model := response.ExtraFields.ModelDeployment + model := response.ExtraFields.ResolvedModelUsed // Verify top level status is present for OpenAI and Azure with non-Claude models if provider != "" && (provider == schemas.OpenAI || provider == schemas.Azure) && !strings.Contains(strings.ToLower(model), "claude") { @@ -988,8 +988,7 @@ func validateResponsesTechnicalFields(t *testing.T, response *schemas.BifrostRes // Check model field if expectations.ShouldHaveModel { - if strings.TrimSpace(response.Model) == "" && - strings.TrimSpace(response.ExtraFields.ModelDeployment) == "" { + if 
strings.TrimSpace(response.Model) == "" { result.Passed = false result.Errors = append(result.Errors, fmt.Sprintf("Expected model field but not present or empty (provider: %s)", response.ExtraFields.Provider)) } diff --git a/core/internal/llmtests/server_tools_via_openai.go b/core/internal/llmtests/server_tools_via_openai.go new file mode 100644 index 0000000000..c5ee1d2000 --- /dev/null +++ b/core/internal/llmtests/server_tools_via_openai.go @@ -0,0 +1,152 @@ +package llmtests + +import ( + "context" + "os" + "strings" + "testing" + + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" +) + +// RunServerToolsViaOpenAIEndpointTest reproduces the user-reported bug where +// sending an Anthropic-server-tool-shaped entry in tools[] via the OpenAI- +// compatible chat-completions endpoint was silently dropped (Claude responded +// with a prose "I can't check real-time data" fallback). The fix was a +// combination of: +// - ChatTool schema gaining Name + all server-tool variant fields. +// - ToAnthropicChatRequest learning to convert non-function tools (server +// tools) into AnthropicTool with the correct variant embed. +// +// This test sends the exact curl-reported shape via BifrostChatRequest + +// ChatCompletionRequest and asserts the request succeeds end-to-end against +// the provider. It covers three server tools that have single-turn triggers +// (web_search, web_fetch, code_execution) across all supporting providers per +// Table 20. Other variants (bash, memory, text_editor, tool_search, +// mcp_toolset, computer_use) require multi-turn tool loops or infra setup +// and are covered by the schema / unit-level round-trip tests instead. 
+func RunServerToolsViaOpenAIEndpointTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) { + if !testConfig.Scenarios.ServerToolsViaOpenAIEndpoint { + t.Logf("ServerToolsViaOpenAIEndpoint not supported for provider %s", testConfig.Provider) + return + } + + cases := []struct { + name string + toolType schemas.ChatToolType + toolName string + prompt string + // extra lets the case set server-tool metadata (max_uses etc.). + extra func(*schemas.ChatTool) + // supported reports whether this tool is supported on the given + // provider per Table 20 (cited provider feature matrix). + supported func(schemas.ModelProvider) bool + }{ + { + name: "web_search", + toolType: "web_search_20260209", + toolName: "web_search", + prompt: "What is the weather in San Francisco today? Use the web_search tool.", + extra: func(t *schemas.ChatTool) { + five := 5 + t.MaxUses = &five + t.AllowedCallers = []string{"direct"} + }, + // web_search: Anthropic + Vertex + Azure per Table 20 (not Bedrock). + supported: func(p schemas.ModelProvider) bool { + return p == schemas.Anthropic || p == schemas.Vertex || p == schemas.Azure + }, + }, + { + name: "web_fetch", + toolType: "web_fetch_20260309", + toolName: "web_fetch", + prompt: "Fetch https://example.com and summarise the title.", + extra: func(t *schemas.ChatTool) { + three := 3 + t.MaxUses = &three + }, + // web_fetch: Anthropic + Azure only per Table 20. + supported: func(p schemas.ModelProvider) bool { + return p == schemas.Anthropic || p == schemas.Azure + }, + }, + { + name: "code_execution", + toolType: "code_execution_20250825", + toolName: "code_execution", + prompt: "Compute 2^64 minus 1 using the code_execution tool and return the result.", + // code_execution: Anthropic + Azure only per Table 20. 
+ supported: func(p schemas.ModelProvider) bool { + return p == schemas.Anthropic || p == schemas.Azure + }, + }, + } + + t.Run("ServerToolsViaOpenAIEndpoint", func(t *testing.T) { + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + if !tc.supported(testConfig.Provider) { + t.Skipf("%s not supported on %s per Table 20", tc.name, testConfig.Provider) + } + if os.Getenv("SKIP_PARALLEL_TESTS") != "true" { + t.Parallel() + } + + tool := schemas.ChatTool{ + Type: tc.toolType, + Name: tc.toolName, + } + if tc.extra != nil { + tc.extra(&tool) + } + + req := &schemas.BifrostChatRequest{ + Provider: testConfig.Provider, + Model: testConfig.ChatModel, + Input: []schemas.ChatMessage{ + CreateBasicChatMessage(tc.prompt), + }, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(500), + Tools: []schemas.ChatTool{tool}, + }, + Fallbacks: testConfig.Fallbacks, + } + + bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) + resp, err := client.ChatCompletionRequest(bfCtx, req) + if err != nil { + t.Fatalf("%s tool request failed: %s", tc.name, GetErrorMessage(err)) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + + // Regression signals: + // 1. Upstream accepted the request (no error). + // 2. Response is not the prose fallback Claude emits when + // the server-tool was silently stripped pre-fix + // ("I can't/cannot/don't have access to real-time ..."). + // The schema + conversion unit tests prove the outbound + // request carries the tool; this live test proves the + // provider accepts the shape AND actually uses the tool + // rather than answering from parametric memory. 
+ content := GetChatContent(resp) + lc := strings.ToLower(content) + if strings.Contains(lc, "can't access real-time") || + strings.Contains(lc, "cannot access real-time") || + strings.Contains(lc, "don't have access to real-time") { + t.Fatalf("%s regression: tool appears to be ignored, content=%q", tc.name, content) + } + t.Logf("%s tool live call succeeded: chars=%d", tc.name, len(content)) + }) + } + }) +} diff --git a/core/internal/llmtests/speech_synthesis.go b/core/internal/llmtests/speech_synthesis.go index 4e08d6e2c8..aae66423a3 100644 --- a/core/internal/llmtests/speech_synthesis.go +++ b/core/internal/llmtests/speech_synthesis.go @@ -239,8 +239,8 @@ func RunSpeechSynthesisAdvancedTest(t *testing.T, client *bifrost.Bifrost, ctx c t.Fatalf("HD audio data too small: got %d bytes, expected at least 5000", audioSize) } - if speechResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.ModelRequested) + if speechResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Expected HD model, got: %s", speechResponse.ExtraFields.OriginalModelRequested) } t.Logf("✅ HD speech synthesis successful: %d bytes generated", len(speechResponse.Audio)) @@ -344,8 +344,8 @@ func validateSpeechSynthesisSpecific(t *testing.T, response *schemas.BifrostSpee t.Fatalf("Audio data too small: got %d bytes, expected at least %d", audioSize, expectMinBytes) } - if expectedModel != "" && response.ExtraFields.ModelRequested != expectedModel { - t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.ModelRequested) + if expectedModel != "" && response.ExtraFields.OriginalModelRequested != expectedModel { + t.Logf("⚠️ Expected model, got: %s", response.ExtraFields.OriginalModelRequested) } t.Logf("✅ Audio validation passed: %d bytes generated", audioSize) diff --git a/core/internal/llmtests/speech_synthesis_stream.go b/core/internal/llmtests/speech_synthesis_stream.go 
index 87268f3c17..8b7bdc8efb 100644 --- a/core/internal/llmtests/speech_synthesis_stream.go +++ b/core/internal/llmtests/speech_synthesis_stream.go @@ -184,8 +184,8 @@ func RunSpeechSynthesisStreamTest(t *testing.T, client *bifrost.Bifrost, ctx con if response.BifrostSpeechStreamResponse.Type != "" && (response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDelta && response.BifrostSpeechStreamResponse.Type != schemas.SpeechStreamResponseTypeDone) { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostSpeechStreamResponse.Type) } - if response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } @@ -348,8 +348,8 @@ func RunSpeechSynthesisStreamAdvancedTest(t *testing.T, client *bifrost.Bifrost, t.Logf("✅ HD chunk %d: %d bytes", chunkCount, chunkSize) } - if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested != testConfig.SpeechSynthesisModel { - t.Logf("⚠️ Unexpected HD model: %s", response.BifrostSpeechStreamResponse.ExtraFields.ModelRequested) + if response.BifrostSpeechStreamResponse != nil && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != "" && response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested != testConfig.SpeechSynthesisModel { + t.Logf("⚠️ Unexpected HD model: %s", 
response.BifrostSpeechStreamResponse.ExtraFields.OriginalModelRequested) } } diff --git a/core/internal/llmtests/tests.go b/core/internal/llmtests/tests.go index af3006b9a1..108894feb4 100644 --- a/core/internal/llmtests/tests.go +++ b/core/internal/llmtests/tests.go @@ -120,6 +120,8 @@ func RunAllComprehensiveTests(t *testing.T, client *bifrost.Bifrost, ctx context RunCompactionTest, RunInterleavedThinkingTest, RunFastModeTest, + RunEagerInputStreamingTest, + RunServerToolsViaOpenAIEndpointTest, } // Execute all test scenarios without raw request/response (default behavior) @@ -239,6 +241,8 @@ func printTestSummary(t *testing.T, testConfig ComprehensiveTestConfig) { {"Compaction", testConfig.Scenarios.Compaction}, {"InterleavedThinking", testConfig.Scenarios.InterleavedThinking}, {"FastMode", testConfig.Scenarios.FastMode}, + {"EagerInputStreaming", testConfig.Scenarios.EagerInputStreaming}, + {"ServerToolsViaOpenAIEndpoint", testConfig.Scenarios.ServerToolsViaOpenAIEndpoint}, } supported := 0 diff --git a/core/internal/llmtests/transcription_stream.go b/core/internal/llmtests/transcription_stream.go index dfc80fc533..a28239c00f 100644 --- a/core/internal/llmtests/transcription_stream.go +++ b/core/internal/llmtests/transcription_stream.go @@ -242,8 +242,12 @@ func RunTranscriptionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx conte if response.BifrostTranscriptionStreamResponse.Type != schemas.TranscriptionStreamResponseTypeDelta { t.Logf("⚠️ Unexpected object type in stream: %s", response.BifrostTranscriptionStreamResponse.Type) } - if response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != "" && response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested != testConfig.TranscriptionModel { - t.Logf("⚠️ Unexpected model in stream: %s", response.BifrostTranscriptionStreamResponse.ExtraFields.ModelRequested) + gotModel := response.BifrostTranscriptionStreamResponse.ExtraFields.OriginalModelRequested + if gotModel == "" { + 
t.Fatal("❌ Stream chunk missing extra_fields.original_model_requested") + } + if gotModel != testConfig.TranscriptionModel { + t.Fatalf("❌ Unexpected original_model_requested in stream: got %q want %q", gotModel, testConfig.TranscriptionModel) } lastResponse = DeepCopyBifrostStreamChunk(response) diff --git a/core/internal/llmtests/video.go b/core/internal/llmtests/video.go index c622edf6b4..8ac2d6e396 100644 --- a/core/internal/llmtests/video.go +++ b/core/internal/llmtests/video.go @@ -48,8 +48,8 @@ func RunVideoGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.C if resp.ExtraFields.Provider == "" { t.Fatal("❌ Video generation extra_fields.provider is empty") } - if resp.ExtraFields.ModelRequested == "" { - t.Fatal("❌ Video generation extra_fields.model_requested is empty") + if resp.ExtraFields.OriginalModelRequested == "" { + t.Fatal("❌ Video generation extra_fields.original_model_requested is empty") } t.Logf("✅ Video generation created job: id=%s status=%s", resp.ID, resp.Status) diff --git a/core/internal/llmtests/websocket_responses.go b/core/internal/llmtests/websocket_responses.go index 420a049fb7..966463dade 100644 --- a/core/internal/llmtests/websocket_responses.go +++ b/core/internal/llmtests/websocket_responses.go @@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/mcptests/agent_test_helpers.go b/core/internal/mcptests/agent_test_helpers.go index d19d953ca0..85512dcce6 100644 --- a/core/internal/mcptests/agent_test_helpers.go +++ 
b/core/internal/mcptests/agent_test_helpers.go @@ -131,11 +131,11 @@ func SetupAgentTest(t *testing.T, config AgentTestConfig) (*mcp.MCPManager, *Dyn // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -192,11 +192,11 @@ func SetupAgentTestWithClients(t *testing.T, config AgentTestConfig, customClien // Create context with filtering baseCtx := context.Background() - if len(config.ClientFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, config.ClientFiltering) + if config.ClientFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, config.ClientFiltering) } - if len(config.ToolFiltering) > 0 { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, config.ToolFiltering) + if config.ToolFiltering != nil { + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, config.ToolFiltering) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) diff --git a/core/internal/mcptests/annotations_test.go b/core/internal/mcptests/annotations_test.go new file mode 100644 index 0000000000..e85b54a79f --- /dev/null +++ b/core/internal/mcptests/annotations_test.go @@ -0,0 +1,220 @@ +package mcptests + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +// ============================================================================= +// MCP ANNOTATION TESTS +// +// These tests verify two invariants of the MCP annotations feature: +// +// 1. PRESERVATION: annotations attached to a registered tool survive the full +// MCP→Bifrost conversion and remain accessible on ChatTool.Annotations +// after retrieval from the manager. +// +// 2. ISOLATION: annotations are tagged json:"-" on ChatTool, so they are never +// included in the JSON body forwarded to LLM providers. +// ============================================================================= + +// TestAnnotations_PreservedAfterToolRegistration verifies that annotations set +// on an InProcess ChatTool schema are stored in the tool map without modification. +func TestAnnotations_PreservedAfterToolRegistration(t *testing.T) { + t.Parallel() + + readOnly := true + idempotent := true + + manager := setupMCPManager(t) + + toolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "read_resource", + Description: schemas.Ptr("Reads a resource"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: schemas.NewOrderedMapFromPairs( + schemas.KV("uri", map[string]interface{}{ + "type": "string", + "description": "URI of the resource to read", + }), + ), + Required: []string{"uri"}, + }, + }, + Annotations: &schemas.MCPToolAnnotations{ + Title: "Resource Reader", + ReadOnlyHint: &readOnly, + IdempotentHint: &idempotent, + }, + } + + err := manager.RegisterTool( + "read_resource", + "Reads a resource", + func(args any) (string, error) { return `{"ok":true}`, nil }, + toolSchema, + ) + require.NoError(t, err) + + ctx := createTestContext() + toolPerClient := manager.GetToolPerClient(ctx) + + var found *schemas.ChatTool +outer1: + for _, tools := range toolPerClient { + for i := range tools { + if tools[i].Function != nil && 
strings.HasSuffix(tools[i].Function.Name, "-read_resource") { + cp := tools[i] + found = &cp + break outer1 + } + } + } + require.NotNil(t, found, "read_resource tool should be present in the tool map") + + // Annotations must be preserved on ChatTool (not lost after registration) + require.NotNil(t, found.Annotations, "Annotations should be preserved on ChatTool") + assert.Equal(t, "Resource Reader", found.Annotations.Title) + require.NotNil(t, found.Annotations.ReadOnlyHint) + assert.True(t, *found.Annotations.ReadOnlyHint) + require.NotNil(t, found.Annotations.IdempotentHint) + assert.True(t, *found.Annotations.IdempotentHint) + assert.Nil(t, found.Annotations.DestructiveHint) + assert.Nil(t, found.Annotations.OpenWorldHint) +} + +// TestAnnotations_AbsentFromProviderJSON verifies that annotations do NOT appear +// in the JSON representation of a tool — i.e. the payload that would be forwarded +// to an LLM provider. +func TestAnnotations_AbsentFromProviderJSON(t *testing.T) { + t.Parallel() + + readOnly := true + destructive := false + + manager := setupMCPManager(t) + + toolSchema := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "write_file", + Description: schemas.Ptr("Writes content to a file"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: schemas.NewOrderedMapFromPairs( + schemas.KV("path", map[string]interface{}{ + "type": "string", + "description": "Destination file path", + }), + schemas.KV("content", map[string]interface{}{ + "type": "string", + "description": "Content to write", + }), + ), + Required: []string{"path", "content"}, + }, + }, + Annotations: &schemas.MCPToolAnnotations{ + Title: "File Writer", + ReadOnlyHint: &readOnly, + DestructiveHint: &destructive, + }, + } + + err := manager.RegisterTool( + "write_file", + "Writes content to a file", + func(args any) (string, error) { return `{"ok":true}`, nil }, + toolSchema, + ) + require.NoError(t, err) + + 
ctx := createTestContext() + toolPerClient := manager.GetToolPerClient(ctx) + + var found *schemas.ChatTool +outer2: + for _, tools := range toolPerClient { + for i := range tools { + if tools[i].Function != nil && strings.HasSuffix(tools[i].Function.Name, "-write_file") { + cp := tools[i] + found = &cp + break outer2 + } + } + } + require.NotNil(t, found, "write_file tool should be present in the tool map") + + // The tool must have annotations in memory + require.NotNil(t, found.Annotations, "Annotations must be in memory for downstream use") + + // Serialize the tool as a provider would receive it + toolJSON, err := json.Marshal(found) + require.NoError(t, err) + s := string(toolJSON) + + // None of the annotation data must leak into the JSON. + // Use the key token `"annotations":` to avoid false positives from description text. + assert.NotContains(t, s, `"annotations":`, "annotations key must be absent from provider JSON") + assert.NotContains(t, s, "readOnlyHint", "readOnlyHint must be absent from provider JSON") + assert.NotContains(t, s, "destructiveHint", "destructiveHint must be absent from provider JSON") + assert.NotContains(t, s, "File Writer", "annotation title must be absent from provider JSON") + + // The function definition itself must still be present + assert.Contains(t, s, "write_file", "function name must be present in provider JSON") + assert.Contains(t, s, "path", "parameter must be present in provider JSON") +} + +// TestAnnotations_DeepCopyPreservesAnnotations verifies that the deep-copy path +// (used during plugin accumulation and streaming) correctly copies annotations. 
+func TestAnnotations_DeepCopyPreservesAnnotations(t *testing.T) { + t.Parallel() + + readOnly := true + + original := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "read_config", + Description: schemas.Ptr("Reads configuration from disk"), + }, + Annotations: &schemas.MCPToolAnnotations{ + Title: "Config Reader", + ReadOnlyHint: &readOnly, + }, + } + + copied := schemas.DeepCopyChatTool(original) + + // Annotations must survive the deep copy + require.NotNil(t, copied.Annotations, "Annotations must be preserved after deep copy") + assert.Equal(t, "Config Reader", copied.Annotations.Title) + require.NotNil(t, copied.Annotations.ReadOnlyHint) + assert.True(t, *copied.Annotations.ReadOnlyHint) + + // Mutate via the pointed-to value to detect pointer aliasing + *original.Annotations.ReadOnlyHint = false + assert.NotSame(t, original.Annotations.ReadOnlyHint, copied.Annotations.ReadOnlyHint, + "deep copy must not share the ReadOnlyHint pointer with the original") + assert.True(t, *copied.Annotations.ReadOnlyHint, + "mutating original's ReadOnlyHint must not affect the deep copy") + + // JSON of the copy must also be annotation-free (same guarantee as the original) + toolJSON, err := json.Marshal(copied) + require.NoError(t, err) + s := string(toolJSON) + // Check for the JSON key pattern, not just the substring, to avoid false positives + // from description text. The key would appear as `"annotations":` in JSON. 
+ assert.NotContains(t, s, `"annotations":`, + "annotations key must be absent from provider JSON even after deep copy") + assert.NotContains(t, s, "readOnlyHint", + "readOnlyHint must be absent from provider JSON even after deep copy") +} diff --git a/core/internal/mcptests/codemode_stdio_test.go b/core/internal/mcptests/codemode_stdio_test.go index 8fe5841a82..aab3a15172 100644 --- a/core/internal/mcptests/codemode_stdio_test.go +++ b/core/internal/mcptests/codemode_stdio_test.go @@ -56,27 +56,27 @@ func setupCodeModeWithSTDIOServers(t *testing.T, serverNames ...string) (*mcp.MC config = GetTemperatureMCPClientConfig(bifrostRoot) config.IsCodeModeClient = true config.ID = "temperature-client" // Match test expectations - config.Name = "temperature" // Use lowercase to match test code + config.Name = "temperature" // Use lowercase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "go-test-server": config = GetGoTestServerConfig(bifrostRoot) config.ID = "goTestServer-client" // Match test expectations - config.Name = "goTestServer" // Use camelCase to match test code + config.Name = "goTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "edge-case-server": config = GetEdgeCaseServerConfig(bifrostRoot) config.ID = "edgeCaseServer-client" // Match test expectations - config.Name = "edgeCaseServer" // Use camelCase to match test code + config.Name = "edgeCaseServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "error-test-server": config = GetErrorTestServerConfig(bifrostRoot) config.ID = "errorTestServer-client" // Match test expectations - config.Name = "errorTestServer" // Use camelCase to match test code + config.Name = "errorTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", 
"listToolFiles", "readToolFile"} case "parallel-test-server": config = GetParallelTestServerConfig(bifrostRoot) config.ID = "parallelTestServer-client" // Match test expectations - config.Name = "parallelTestServer" // Use camelCase to match test code + config.Name = "parallelTestServer" // Use camelCase to match test code config.ToolsToAutoExecute = []string{"executeToolCode", "listToolFiles", "readToolFile"} case "test-tools-server": // test-tools-server doesn't have a fixture, set up manually @@ -367,9 +367,9 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { expectedError string }{ { - name: "allow_only_test_tools_server", - includeClients: []string{"testToolsServer"}, - code: `result = testToolsServer.echo(message="allowed")`, + name: "allow_only_test_tools_server", + includeClients: []string{"testToolsServer"}, + code: `result = testToolsServer.echo(message="allowed")`, shouldSucceed: true, expectedInResult: "allowed", }, @@ -377,13 +377,13 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_test_tools_server", includeClients: []string{"temperature"}, code: `result = testToolsServer.echo(message="blocked")`, - shouldSucceed: false, - expectedError: "undefined: testToolsServer", + shouldSucceed: false, + expectedError: "undefined: testToolsServer", }, { - name: "allow_only_temperature_server", - includeClients: []string{"temperature"}, - code: `result = temperature.get_temperature(location="Paris")`, + name: "allow_only_temperature_server", + includeClients: []string{"temperature"}, + code: `result = temperature.get_temperature(location="Paris")`, shouldSucceed: true, expectedInResult: "Paris", }, @@ -391,8 +391,8 @@ func TestCodeMode_STDIO_ServerFiltering(t *testing.T) { name: "block_temperature_server", includeClients: []string{"testToolsServer"}, code: `result = temperature.get_temperature(location="blocked")`, - shouldSucceed: false, - expectedError: "undefined: temperature", + shouldSucceed: false, + expectedError: "undefined: 
temperature", }, { name: "allow_both_servers", @@ -409,7 +409,7 @@ result = {"echo": echo, "temp": temp}`, t.Run(tc.name, func(t *testing.T) { // Create context with client filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied at tool listing level @@ -524,7 +524,7 @@ result = {"echo": echo, "calc": calc}`, t.Run(tc.name, func(t *testing.T) { // Create context with tool filtering baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) // Verify filtering is applied @@ -622,10 +622,10 @@ result = {"echo": echo, "temp": temp}`, // Create context with both client and tool filtering baseCtx := context.Background() if tc.includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) } if tc.includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, tc.includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, tc.includeTools) } ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) @@ -1692,7 +1692,7 @@ result = {"count": 3}`, for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, tc.includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, tc.includeClients) ctx := schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) toolCall := 
CreateExecuteToolCodeCall(fmt.Sprintf("call-%s", tc.name), tc.code) diff --git a/core/internal/mcptests/concurrency_advanced_test.go b/core/internal/mcptests/concurrency_advanced_test.go index a1c3823831..e3c5793df4 100644 --- a/core/internal/mcptests/concurrency_advanced_test.go +++ b/core/internal/mcptests/concurrency_advanced_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -533,14 +532,14 @@ func TestConcurrent_FilteringChanges(t *testing.T) { if id%2 == 0 { // Even: allow all tools baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-*"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } else { // Odd: allow only echo baseCtx := context.Background() - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, []string{"*"}) - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, []string{"*"}) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, []string{"bifrostInternal-echo"}) ctx = schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/internal/mcptests/fixtures.go b/core/internal/mcptests/fixtures.go index 88b00a9f70..f760ae5ac0 100644 --- a/core/internal/mcptests/fixtures.go +++ b/core/internal/mcptests/fixtures.go @@ -1422,7 +1422,7 @@ func (a *testAccount) GetKeysForProvider(ctx context.Context, providerKey schema return []schemas.Key{ { Value: 
*schemas.NewEnvVar(apiKey), - Models: []string{}, // Empty means all models + Models: schemas.WhiteList{"*"}, Weight: 1.0, }, }, nil @@ -1460,6 +1460,17 @@ func setupBifrost(t *testing.T) *bifrost.Bifrost { return bifrostInstance } +// noopPluginPipeline is a passthrough pipeline used in tests that don't need plugin hooks. +type noopPluginPipeline struct{} + +func (n *noopPluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) { + return req, nil, 0 +} + +func (n *noopPluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { + return mcpResp, bifrostErr +} + // setupMCPManager creates an MCP manager for testing func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mcp.MCPManager { t.Helper() @@ -1472,9 +1483,14 @@ func setupMCPManager(t *testing.T, clientConfigs ...schemas.MCPClientConfig) *mc clientConfigPtrs[i] = &clientConfigs[i] } - // Create MCP config + // Create MCP config with a no-op plugin pipeline so that codemode tool calls + // work correctly even when no Bifrost instance is attached. 
mcpConfig := &schemas.MCPConfig{ ClientConfigs: clientConfigPtrs, + PluginPipelineProvider: func() interface{} { + return &noopPluginPipeline{} + }, + ReleasePluginPipeline: func(pipeline interface{}) {}, } // Create Starlark CodeMode @@ -1984,10 +2000,10 @@ func AssertExecutionTimeUnder(t *testing.T, fn func(), maxDuration time.Duration func CreateTestContextWithMCPFilter(includeClients []string, includeTools []string) *schemas.BifrostContext { baseCtx := context.Background() if includeClients != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeClients, includeClients) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeClients, includeClients) } if includeTools != nil { - baseCtx = context.WithValue(baseCtx, mcp.MCPContextKeyIncludeTools, includeTools) + baseCtx = context.WithValue(baseCtx, schemas.MCPContextKeyIncludeTools, includeTools) } return schemas.NewBifrostContext(baseCtx, schemas.NoDeadline) } diff --git a/core/internal/mcptests/tool_filtering_test.go b/core/internal/mcptests/tool_filtering_test.go index eb8b370a28..15fde03d75 100644 --- a/core/internal/mcptests/tool_filtering_test.go +++ b/core/internal/mcptests/tool_filtering_test.go @@ -160,7 +160,7 @@ func TestToolsToExecute_ExplicitList(t *testing.T) { // Verify configuration was set correctly clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) } func TestToolsToExecute_SingleTool(t *testing.T) { @@ -178,10 +178,10 @@ func TestToolsToExecute_SingleTool(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) // Verify it's not allow-all - assert.NotEqual(t, []string{"*"}, 
clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard") + assert.NotEqual(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute, "should not be wildcard") } // ============================================================================= @@ -204,8 +204,8 @@ func TestToolsToAutoExecute_Basic(t *testing.T) { // Verify the client was created with correct configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToAutoExecute) } func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) { @@ -224,8 +224,8 @@ func TestToolsToAutoExecute_NotInExecuteList(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToAutoExecute) assert.NotEqual(t, clients[0].ExecutionConfig.ToolsToExecute, clients[0].ExecutionConfig.ToolsToAutoExecute) } @@ -245,7 +245,7 @@ func TestToolsToAutoExecute_Wildcard(t *testing.T) { // Verify configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToAutoExecute) } // ============================================================================= @@ -267,7 +267,7 @@ func TestContextFilteringRestrictsWildcard(t *testing.T) { // Verify client configuration 
allows all clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"*"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"*"}, clients[0].ExecutionConfig.ToolsToExecute) // Context restricts to only specific tools (verify context works separately) ctx := CreateTestContextWithMCPFilter(nil, []string{"encode"}) @@ -305,9 +305,9 @@ func TestFilteringMultipleClients_DifferentRules(t *testing.T) { // Find and verify each client for _, client := range clients { if client.ExecutionConfig.ID == "stdio-client-1" { - assert.Equal(t, []string{"encode"}, client.ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, client.ExecutionConfig.ToolsToExecute) } else if client.ExecutionConfig.ID == "stdio-client-2" { - assert.Equal(t, []string{"*"}, client.ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"*"}, client.ExecutionConfig.ToolsToExecute) } } } @@ -331,7 +331,7 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Verify initial configuration clients := manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) // Edit client to only allow second tool clientConfig.ToolsToExecute = []string{"hash"} @@ -341,6 +341,6 @@ func TestFilteringChangesAfterClientEdit(t *testing.T) { // Verify configuration changed clients = manager.GetClients() require.Len(t, clients, 1) - assert.Equal(t, []string{"hash"}, clients[0].ExecutionConfig.ToolsToExecute) - assert.NotEqual(t, []string{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.Equal(t, schemas.WhiteList{"hash"}, clients[0].ExecutionConfig.ToolsToExecute) + assert.NotEqual(t, schemas.WhiteList{"encode"}, clients[0].ExecutionConfig.ToolsToExecute) } diff --git a/core/keyselectors/weightedrandom.go b/core/keyselectors/weightedrandom.go new file mode 100644 
index 0000000000..a3b8994d0d --- /dev/null +++ b/core/keyselectors/weightedrandom.go @@ -0,0 +1,35 @@ +package keyselectors + +import ( + "math/rand" + + "github.com/maximhq/bifrost/core/schemas" +) + +func WeightedRandom(ctx *schemas.BifrostContext, keys []schemas.Key, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { + // Use a weighted random selection based on key weights + totalWeight := 0 + for _, key := range keys { + totalWeight += int(key.Weight * 100) // Convert float to int for better performance + } + + // If all keys have zero weight, fall back to uniform random selection + if totalWeight == 0 { + return keys[rand.Intn(len(keys))], nil + } + + // Use global thread-safe random (Go 1.20+) - no allocation, no syscall + randomValue := rand.Intn(totalWeight) + + // Select key based on weight + currentWeight := 0 + for _, key := range keys { + currentWeight += int(key.Weight * 100) + if randomValue < currentWeight { + return key, nil + } + } + + // Fallback to first key if something goes wrong + return keys[0], nil +} diff --git a/core/mcp/agent.go b/core/mcp/agent.go index fe4481ad7a..96d16ec24e 100644 --- a/core/mcp/agent.go +++ b/core/mcp/agent.go @@ -1,6 +1,7 @@ package mcp import ( + "errors" "fmt" "strings" "sync" @@ -10,7 +11,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) - type AgentModeExecutor struct { logger schemas.Logger } @@ -40,7 +40,7 @@ func (a *AgentModeExecutor) ExecuteAgentForChatRequest( makeReq func(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError), fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Create adapter for Chat API adapter := &chatAPIAdapter{ @@ -143,7 +143,7 @@ func (a 
*AgentModeExecutor) executeAgent( adapter agentAPIAdapter, fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, executeToolFunc func(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error), - clientManager ClientManager, + clientManager ClientManager, ) (interface{}, *schemas.BifrostError) { // Get initial response from adapter currentResponse := adapter.getInitialResponse() @@ -157,6 +157,9 @@ func (a *AgentModeExecutor) executeAgent( allExecutedToolResults := make([]*schemas.ChatMessage, 0) allExecutedToolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + // Accumulate token usage across all LLM calls in the agent loop + accumulatedUsage := adapter.extractUsage(currentResponse) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) if ok { ctx.SetValue(schemas.BifrostMCPAgentOriginalRequestID, originalRequestID) @@ -207,14 +210,8 @@ func (a *AgentModeExecutor) executeAgent( continue } - // Step 1: Convert literal \n escape sequences to actual newlines for parsing - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - if len(codeWithNewlines) != len(code) { - a.logger.Debug("%s Converted literal \\n escape sequences to actual newlines", CodeModeLogPrefix) - } - - // Step 2: Extract tool calls from code during AST formation - extractedToolCalls, err := extractToolCallsFromCode(codeWithNewlines) + // Step 1: Extract tool calls from the original source code during validation + extractedToolCalls, err := extractToolCallsFromCode(code) if err != nil { a.logger.Debug("%s Failed to parse code for tool calls: %v", CodeModeLogPrefix, err) nonAutoExecutableTools = append(nonAutoExecutableTools, toolCall) @@ -289,6 +286,8 @@ func (a *AgentModeExecutor) executeAgent( wg := sync.WaitGroup{} wg.Add(len(autoExecutableTools)) channelToolResults := make(chan *schemas.ChatMessage, len(autoExecutableTools)) + var authRequiredErr *schemas.MCPUserOAuthRequiredError + var authRequiredOnce 
sync.Once for _, toolCall := range autoExecutableTools { go func(toolCall schemas.ChatAssistantMessageToolCall) { defer wg.Done() @@ -305,6 +304,15 @@ func (a *AgentModeExecutor) executeAgent( mcpResponse, toolErr := executeToolFunc(toolCtx, mcpRequest) if toolErr != nil { + // Check if this is a per-user OAuth auth-required error + var oauthErr *schemas.MCPUserOAuthRequiredError + if errors.As(toolErr, &oauthErr) { + authRequiredOnce.Do(func() { + authRequiredErr = oauthErr + }) + channelToolResults <- createToolResultMessage(toolCall, "", toolErr) + return + } a.logger.Warn("Tool execution failed: %v", toolErr) channelToolResults <- createToolResultMessage(toolCall, "", toolErr) } else if mcpResponse != nil && mcpResponse.ChatMessage != nil { @@ -321,6 +329,23 @@ func (a *AgentModeExecutor) executeAgent( wg.Wait() close(channelToolResults) + // If any tool required per-user OAuth, stop the agent loop and return the error + if authRequiredErr != nil { + statusCode := 401 + errType := "mcp_auth_required" + return nil, &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &statusCode, + Error: &schemas.ErrorField{ + Message: authRequiredErr.Message, + Type: &errType, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + MCPAuthRequired: authRequiredErr, + }, + } + } + // Collect tool results executedToolResults = make([]*schemas.ChatMessage, 0, len(autoExecutableTools)) for toolResult := range channelToolResults { @@ -342,6 +367,8 @@ func (a *AgentModeExecutor) executeAgent( if depth == 1 && len(allExecutedToolResults) == 0 { return currentResponse, nil } + // Apply accumulated usage before building the final response + adapter.applyUsage(currentResponse, accumulatedUsage) // Create response with all executed tool results from all iterations, and non-auto-executable tool calls return adapter.createResponseWithExecutedTools(currentResponse, allExecutedToolResults, allExecutedToolCalls, nonAutoExecutableTools), nil } @@ -364,11 +391,127 @@ func (a 
*AgentModeExecutor) executeAgent( } currentResponse = response + accumulatedUsage = mergeUsage(accumulatedUsage, adapter.extractUsage(currentResponse)) } + adapter.applyUsage(currentResponse, accumulatedUsage) return currentResponse, nil } +// mergeUsage sums token counts and costs from two BifrostLLMUsage values. +// Detail sub-fields are summed when both are present; if only one is non-nil it is kept as-is. +func mergeUsage(base, add *schemas.BifrostLLMUsage) *schemas.BifrostLLMUsage { + if add == nil { + return base + } + if base == nil { + return add + } + + merged := &schemas.BifrostLLMUsage{ + PromptTokens: base.PromptTokens + add.PromptTokens, + CompletionTokens: base.CompletionTokens + add.CompletionTokens, + TotalTokens: base.TotalTokens + add.TotalTokens, + } + + // Merge prompt token details + if base.PromptTokensDetails != nil || add.PromptTokensDetails != nil { + bd := base.PromptTokensDetails + ad := add.PromptTokensDetails + if bd == nil { + bd = &schemas.ChatPromptTokensDetails{} + } + if ad == nil { + ad = &schemas.ChatPromptTokensDetails{} + } + merged.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: bd.TextTokens + ad.TextTokens, + AudioTokens: bd.AudioTokens + ad.AudioTokens, + ImageTokens: bd.ImageTokens + ad.ImageTokens, + CachedReadTokens: bd.CachedReadTokens + ad.CachedReadTokens, + CachedWriteTokens: bd.CachedWriteTokens + ad.CachedWriteTokens, + } + } + + // Merge completion token details + if base.CompletionTokensDetails != nil || add.CompletionTokensDetails != nil { + bd := base.CompletionTokensDetails + ad := add.CompletionTokensDetails + if bd == nil { + bd = &schemas.ChatCompletionTokensDetails{} + } + if ad == nil { + ad = &schemas.ChatCompletionTokensDetails{} + } + merged.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: bd.TextTokens + ad.TextTokens, + AcceptedPredictionTokens: bd.AcceptedPredictionTokens + ad.AcceptedPredictionTokens, + AudioTokens: bd.AudioTokens + ad.AudioTokens, 
+ ReasoningTokens: bd.ReasoningTokens + ad.ReasoningTokens, + RejectedPredictionTokens: bd.RejectedPredictionTokens + ad.RejectedPredictionTokens, + } + if bd.CitationTokens != nil || ad.CitationTokens != nil { + bct := 0 + act := 0 + if bd.CitationTokens != nil { + bct = *bd.CitationTokens + } + if ad.CitationTokens != nil { + act = *ad.CitationTokens + } + sum := bct + act + merged.CompletionTokensDetails.CitationTokens = &sum + } + if bd.NumSearchQueries != nil || ad.NumSearchQueries != nil { + bnsq := 0 + ansq := 0 + if bd.NumSearchQueries != nil { + bnsq = *bd.NumSearchQueries + } + if ad.NumSearchQueries != nil { + ansq = *ad.NumSearchQueries + } + sum := bnsq + ansq + merged.CompletionTokensDetails.NumSearchQueries = &sum + } + if bd.ImageTokens != nil || ad.ImageTokens != nil { + bit := 0 + ait := 0 + if bd.ImageTokens != nil { + bit = *bd.ImageTokens + } + if ad.ImageTokens != nil { + ait = *ad.ImageTokens + } + sum := bit + ait + merged.CompletionTokensDetails.ImageTokens = &sum + } + } + + // Merge cost + if base.Cost != nil || add.Cost != nil { + bc := base.Cost + ac := add.Cost + if bc == nil { + bc = &schemas.BifrostCost{} + } + if ac == nil { + ac = &schemas.BifrostCost{} + } + merged.Cost = &schemas.BifrostCost{ + InputTokensCost: bc.InputTokensCost + ac.InputTokensCost, + OutputTokensCost: bc.OutputTokensCost + ac.OutputTokensCost, + ReasoningTokensCost: bc.ReasoningTokensCost + ac.ReasoningTokensCost, + CitationTokensCost: bc.CitationTokensCost + ac.CitationTokensCost, + SearchQueriesCost: bc.SearchQueriesCost + ac.SearchQueriesCost, + RequestCost: bc.RequestCost + ac.RequestCost, + TotalCost: bc.TotalCost + ac.TotalCost, + } + } + + return merged +} + // extractToolCalls extracts all tool calls from a chat response. // It iterates through all choices in the response and collects tool calls // from assistant messages. 
@@ -460,25 +603,23 @@ func buildAllowedAutoExecutionTools(ctx *schemas.BifrostContext, clientManager C // Get auto-executable tools from config toolsToAutoExecute := client.ExecutionConfig.ToolsToAutoExecute - if len(toolsToAutoExecute) == 0 { + if toolsToAutoExecute.IsEmpty() { // No auto-executable tools configured for this client continue } // Parse tool names (as they appear in JavaScript code) autoExecutableTools := []string{} - for _, originalToolName := range toolsToAutoExecute { - // Handle wildcard "*" - means all tools are auto-executable - if originalToolName == "*" { - autoExecutableTools = append(autoExecutableTools, "*") - continue + if toolsToAutoExecute.IsUnrestricted() { + autoExecutableTools = append(autoExecutableTools, "*") + } else { + for _, originalToolName := range toolsToAutoExecute { + // Replace - with _ for code mode compatibility, then parse for JS compatibility + toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") + parsedToolName := parseToolName(toolNameForCode) + autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Replace - with _ for code mode compatibility, then parse for JS compatibility - toolNameForCode := strings.ReplaceAll(originalToolName, "-", "_") - parsedToolName := parseToolName(toolNameForCode) - autoExecutableTools = append(autoExecutableTools, parsedToolName) } - // Add to map if there are auto-executable tools if len(autoExecutableTools) > 0 { allowedTools[clientName] = autoExecutableTools diff --git a/core/mcp/agentadaptors.go b/core/mcp/agentadaptors.go index 6986cd9798..7b78df4389 100644 --- a/core/mcp/agentadaptors.go +++ b/core/mcp/agentadaptors.go @@ -59,6 +59,12 @@ type agentAPIAdapter interface { executedToolCalls []schemas.ChatAssistantMessageToolCall, nonAutoExecutableToolCalls []schemas.ChatAssistantMessageToolCall, ) interface{} + + // extractUsage returns the token usage from a response as BifrostLLMUsage. 
+ extractUsage(response interface{}) *schemas.BifrostLLMUsage + + // applyUsage sets accumulated usage on the response in place. + applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) } // chatAPIAdapter implements agentAPIAdapter for Chat API @@ -175,6 +181,14 @@ func (c *chatAPIAdapter) createResponseWithExecutedTools( ) } +func (c *chatAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage { + return response.(*schemas.BifrostChatResponse).Usage +} + +func (c *chatAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) { + response.(*schemas.BifrostChatResponse).Usage = usage +} + // createChatResponseWithExecutedToolsAndNonAutoExecutableCalls creates a chat response // that includes executed tool results and non-auto-executable tool calls. The response // contains a formatted text summary of executed tool results and includes the non-auto-executable @@ -390,6 +404,14 @@ func (r *responsesAPIAdapter) createResponseWithExecutedTools( ) } +func (r *responsesAPIAdapter) extractUsage(response interface{}) *schemas.BifrostLLMUsage { + return response.(*schemas.BifrostResponsesResponse).Usage.ToBifrostLLMUsage() +} + +func (r *responsesAPIAdapter) applyUsage(response interface{}, usage *schemas.BifrostLLMUsage) { + response.(*schemas.BifrostResponsesResponse).Usage = usage.ToResponsesResponseUsage() +} + // createResponsesResponseWithExecutedToolsAndNonAutoExecutableCalls creates a responses response // that includes executed tool results and non-auto-executable tool calls. 
The response // contains a formatted text summary of executed tool results and includes the non-auto-executable diff --git a/core/mcp/clientmanager.go b/core/mcp/clientmanager.go index 36b12243da..b6bd442c20 100644 --- a/core/mcp/clientmanager.go +++ b/core/mcp/clientmanager.go @@ -118,6 +118,33 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { // This is to avoid deadlocks when the connection attempt is made m.mu.Unlock() + // Per-user OAuth: skip persistent connection. Auth is per-request at runtime. + // The admin verifies the configuration via a sample login before this is called, + // and tools are populated separately via SetClientTools(). + if configCopy.AuthType == schemas.MCPAuthTypePerUserOauth { + m.mu.Lock() + if client, exists := m.clientMap[config.ID]; exists { + if config.ConnectionString != nil { + url := config.ConnectionString.GetValue() + client.ConnectionInfo.ConnectionURL = &url + } + // Restore discovered tools from config (persisted in DB across restarts) + if len(config.DiscoveredTools) > 0 { + for toolName, tool := range config.DiscoveredTools { + client.ToolMap[toolName] = tool + } + client.ToolNameMapping = config.DiscoveredToolNameMapping + client.State = schemas.MCPConnectionStateConnected + m.logger.Info("%s Per-user OAuth MCP client '%s' restored with %d tools", MCPLogPrefix, config.Name, len(config.DiscoveredTools)) + } else { + client.State = schemas.MCPConnectionStatePendingTools + m.logger.Info("%s Per-user OAuth MCP client '%s' registered (connection deferred to runtime)", MCPLogPrefix, config.Name) + } + } + m.mu.Unlock() + return nil + } + // Connect using the copied config if err := m.connectToMCPClient(configCopy); err != nil { // Clean up the failed entry — this is a user-initiated action (UI/API), @@ -131,6 +158,92 @@ func (m *MCPManager) AddClient(config *schemas.MCPClientConfig) error { return nil } +// VerifyPerUserOAuthConnection creates a temporary MCP connection using the +// provided access 
token to verify the server is reachable and discover available +// tools. The connection is closed after verification. This is used during +// per-user OAuth client setup when the admin does a test login to validate the +// OAuth configuration before saving the MCP client. +// +// Parameters: +// - config: MCP client configuration (connection URL, name, etc.) +// - accessToken: temporary OAuth access token from the admin's test login +// +// Returns: +// - map[string]schemas.ChatTool: discovered tools keyed by prefixed name +// - map[string]string: tool name mapping (sanitized → original MCP name) +// - error: any error during verification +func (m *MCPManager) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, nil, fmt.Errorf("connection URL is required for per-user OAuth verification") + } + + // Create HTTP transport with the admin's temporary Bearer token + headers := map[string]string{ + "Authorization": "Bearer " + accessToken, + } + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, nil, fmt.Errorf("failed to create HTTP transport for verification: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + ctx, cancel := context.WithTimeout(ctx, MCPClientConnectionEstablishTimeout) + defer cancel() + + // Start transport + if err := tempClient.Start(ctx); err != nil { + return nil, nil, fmt.Errorf("failed to start MCP connection for verification: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + 
Name: fmt.Sprintf("Bifrost-%s-verify", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, nil, fmt.Errorf("failed to initialize MCP connection for verification: %w", err) + } + + // Discover tools + tools, toolNameMapping, err := retrieveExternalTools(ctx, tempClient, config.Name, m.logger) + if err != nil { + return nil, nil, fmt.Errorf("failed to discover tools during verification: %w", err) + } + + m.logger.Info("%s Per-user OAuth verification succeeded for '%s': discovered %d tools", MCPLogPrefix, config.Name, len(tools)) + return tools, toolNameMapping, nil +} + +// SetClientTools updates the tool map and name mapping for an existing client. +// This is used to populate tools discovered during per-user OAuth verification, +// where tool discovery happens separately from client creation. +// +// Parameters: +// - clientID: ID of the client to update +// - tools: discovered tools keyed by prefixed name +// - toolNameMapping: mapping from sanitized tool names to original MCP names +func (m *MCPManager) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { + m.mu.Lock() + defer m.mu.Unlock() + + if client, exists := m.clientMap[clientID]; exists { + for toolName, tool := range tools { + client.ToolMap[toolName] = tool + } + client.ToolNameMapping = toolNameMapping + client.State = schemas.MCPConnectionStateConnected + m.logger.Debug("%s Set %d tools on client '%s'", MCPLogPrefix, len(tools), client.Name) + } +} + // RemoveClient removes an MCP client from the manager. // It handles cleanup for all transport types (HTTP, STDIO, SSE). 
// @@ -243,13 +356,15 @@ func (m *MCPManager) UpdateClient(id string, updatedConfig *schemas.MCPClientCon ConfigHash: client.ExecutionConfig.ConfigHash, ToolPricing: maps.Clone(client.ExecutionConfig.ToolPricing), // Updatable fields - copy from updated config with proper cloning - Name: updatedConfig.Name, - IsCodeModeClient: updatedConfig.IsCodeModeClient, - Headers: maps.Clone(updatedConfig.Headers), - ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), - ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), - IsPingAvailable: updatedConfig.IsPingAvailable, - ToolSyncInterval: updatedConfig.ToolSyncInterval, + Name: updatedConfig.Name, + IsCodeModeClient: updatedConfig.IsCodeModeClient, + Headers: maps.Clone(updatedConfig.Headers), + ToolsToExecute: slices.Clone(updatedConfig.ToolsToExecute), + ToolsToAutoExecute: slices.Clone(updatedConfig.ToolsToAutoExecute), + AllowedExtraHeaders: slices.Clone(updatedConfig.AllowedExtraHeaders), + IsPingAvailable: updatedConfig.IsPingAvailable, + ToolSyncInterval: updatedConfig.ToolSyncInterval, + AllowOnAllVirtualKeys: updatedConfig.AllowOnAllVirtualKeys, } // Atomically replace the config pointer @@ -663,7 +778,11 @@ func (m *MCPManager) connectToMCPClient(config *schemas.MCPClientConfig) error { } // Start health monitoring for the client - monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, config.IsPingAvailable, m.logger) + isPingAvailable := true + if config.IsPingAvailable != nil { + isPingAvailable = *config.IsPingAvailable + } + monitor := NewClientHealthMonitor(m, config.ID, DefaultHealthCheckInterval, isPingAvailable, m.logger) m.healthMonitorManager.StartMonitoring(monitor) // Start tool syncing for the client (skip for internal bifrost client) diff --git a/core/mcp/codemode.go b/core/mcp/codemode.go index e81c984195..fa11e52d0b 100644 --- a/core/mcp/codemode.go +++ b/core/mcp/codemode.go @@ -3,7 +3,6 @@ package mcp import ( - "context" "sync" "time" @@ -31,7 +30,7 @@ 
type CodeMode interface { // ExecuteTool handles a code mode tool call by name. // Returns the response message and any error that occurred. - ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) + ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) // IsCodeModeTool returns true if the given tool name is a code mode tool. IsCodeModeTool(toolName string) bool diff --git a/core/mcp/codemode/starlark/executecode.go b/core/mcp/codemode/starlark/executecode.go index da497d8f6f..d2d9435764 100644 --- a/core/mcp/codemode/starlark/executecode.go +++ b/core/mcp/codemode/starlark/executecode.go @@ -5,7 +5,6 @@ package starlark import ( "context" "fmt" - "net/http" "strings" "time" @@ -13,9 +12,11 @@ import ( "github.com/mark3labs/mcp-go/mcp" codemcp "github.com/maximhq/bifrost/core/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" "go.starlark.net/starlarkstruct" + "go.starlark.net/syntax" ) // ExecutionResult represents the result of code execution @@ -52,8 +53,11 @@ type ExecutionEnvironment struct { func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { executeToolCodeProps := schemas.NewOrderedMapFromPairs( schemas.KV("code", map[string]interface{}{ - "type": "string", - "description": "Python code to execute. The code runs in a Starlark interpreter (Python subset). Tool calls are synchronous - no async/await needed. For loops/conditionals, wrap in a function. Use print() for logging. ALWAYS retry if code fails. Example: def main():\n items = server.list_items()\n for item in items:\n print(item)\nresult = main()", + "type": "string", + "description": "Python (Starlark) code to execute. Tool calls are synchronous: result = server.tool(param=\"value\"). " + + "Use print() for logging. Assign to 'result' variable to return a value. 
" + + "Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations. " + + "Example: items = server.list_items()\nfor item in items:\n print(item[\"name\"])\nresult = items", }), ) return schemas.ChatTool{ @@ -61,36 +65,36 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { Function: &schemas.ChatToolFunction{ Name: codemcp.ToolTypeExecuteToolCode, Description: schemas.Ptr( - "Executes Python code inside a sandboxed Starlark interpreter with access to all connected MCP servers' tools. " + - "All connected servers are exposed as global objects named after their configuration keys, and each server " + - "provides functions for every tool available on that server. The canonical usage pattern is: " + - "result = .(param=\"value\"). Both and should be discovered " + - "using listToolFiles and readToolFile. " + - - "IMPORTANT WORKFLOW: Always follow this order — first use listToolFiles to see available servers and tools, " + - "then use readToolFile to understand the tool definitions and their parameters, and finally use executeToolCode " + - "to execute your code. " + + "Executes Python code in a sandboxed Starlark interpreter with MCP server tool access. " + + "Servers are exposed as global objects: result = serverName.toolName(param=\"value\"). " + + "This is the final step of the four-tool code mode workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "If you have not already read a tool's .pyi stub in this conversation, do that before writing code. " + + "Do NOT guess callable tool names from natural language or stale assumptions; use the exact identifier returned by listToolFiles/readToolFile. " + + + "STARLARK DIFFERENCES FROM PYTHON — READ BEFORE WRITING CODE: " + + "1. 
NO try/except/finally/raise — error handling is not supported, and tool failures cannot be caught inside Starlark. " + + "2. NO classes — use dicts and functions. " + + "3. NO imports, direct network access, or direct filesystem access — use MCP tools instead. " + + "4. NO is operator — use == for comparison. " + + "5. NO f-strings — use % formatting: \"Hello %s, count=%d\" % (name, n). " + + "6. Each executeToolCode call runs in a FRESH ISOLATED SCOPE — no variables, functions, or state persist between calls. Re-fetch data or store it via MCP tools (e.g., SQLite, FileSystem) if needed across calls. " + "SYNTAX NOTES: " + - "• Tool calls are synchronous - NO async/await needed, just call directly: result = server.tool(arg=\"value\") " + + "• Synchronous calls — NO async/await: result = server.tool(arg=\"value\") " + "• Use keyword arguments: server.tool(param=\"value\") NOT server.tool({\"param\": \"value\"}) " + "• Access dict values with brackets: result[\"key\"] NOT result.key " + - "• Use print() for logging (not console.log) " + - "• List comprehensions work: [x for x in items if x[\"active\"]] " + - "• To return a value, assign to 'result' variable: result = computed_value " + - "• CRITICAL: for/if/while at top level MUST be inside a function - def main(): ... then result = main() " + - - "RETRY POLICY: ALWAYS retry if a code block fails. Analyze the error, adjust your code, and retry. " + - - "The environment is intentionally minimal: " + - "• No imports needed or supported " + - "• No network APIs (use MCP tools for external interactions) " + - "• No file system access (use MCP tools) " + - "• No classes (use dicts and functions) " + - "• Deterministic execution (no random, no time) " + - - "Long-running operations are interrupted via execution timeout. 
" + - "This tool is designed specifically for orchestrating MCP tool calls and lightweight computation.", + "• Use print() for logging/debugging " + + "• List comprehensions: [x for x in items if x[\"active\"]] " + + "• String escapes work normally: \"line1\\nline2\" produces a newline " + + "• Triple-quoted strings for multiline: \"\"\"multi\\nline\"\"\" " + + "• chr(10) for newline character, chr(9) for tab " + + "• To return a value, assign to 'result': result = computed_value " + + "• MCP tool calls are timeout-limited; avoid long or infinite loops " + + + "AVAILABLE BUILTINS: print, len, range, enumerate, zip, sorted, reversed, min, max, " + + "int, float, str, bool, list, dict, tuple, set, hasattr, getattr, type, chr, ord, any, all, hash, repr. " + + + "RETRY POLICY: Retry after fixing syntax or logic errors, especially for read-only flows. Before rerunning code that already made tool calls, inspect prior outputs and avoid replaying stateful operations.", ), Parameters: &schemas.ToolFunctionParameters{ @@ -103,7 +107,7 @@ func (s *StarlarkCodeMode) createExecuteToolCodeTool() schemas.ChatTool { } // handleExecuteToolCode handles the executeToolCode tool call. -func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) handleExecuteToolCode(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { toolName := "unknown" if toolCall.Function.Name != nil { toolName = *toolCall.Function.Name @@ -197,16 +201,13 @@ func (s *StarlarkCodeMode) handleExecuteToolCode(ctx context.Context, toolCall s } // executeCode executes Python (Starlark) code in a sandboxed interpreter with MCP tool bindings. 
-func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) ExecutionResult { +func (s *StarlarkCodeMode) executeCode(ctx *schemas.BifrostContext, code string) ExecutionResult { logs := []string{} s.logger.Debug("%s Starting Starlark code execution", codemcp.CodeModeLogPrefix) - // Step 1: Convert literal \n escape sequences to actual newlines - codeWithNewlines := strings.ReplaceAll(code, "\\n", "\n") - - // Step 2: Handle empty code - trimmedCode := strings.TrimSpace(codeWithNewlines) + // Step 1: Handle empty code + trimmedCode := strings.TrimSpace(code) if trimmedCode == "" { return ExecutionResult{ Result: nil, @@ -218,7 +219,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } } - // Step 3: Build tool bindings for all connected servers + // Step 2: Build tool bindings for all connected servers availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) serverKeys := make([]string, 0, len(availableToolsPerClient)) predeclared := starlark.StringDict{} @@ -254,9 +255,8 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - parsedToolName := parseToolName(unprefixedToolName) + parsedToolName := getCanonicalToolName(clientName, originalToolName) + compatibilityAlias := getCompatibilityToolAlias(clientName, originalToolName) s.logger.Debug("%s [%s] Binding tool: %s -> %s", codemcp.CodeModeLogPrefix, clientName, originalToolName, parsedToolName) @@ -298,6 +298,13 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }) structMembers[parsedToolName] = toolFunc + + if compatibilityAlias != parsedToolName && isValidStarlarkIdentifier(compatibilityAlias) { + if _, exists := structMembers[compatibilityAlias]; !exists { + structMembers[compatibilityAlias] = toolFunc + 
s.logger.Debug("%s [%s] Added compatibility alias: %s -> %s", codemcp.CodeModeLogPrefix, clientName, compatibilityAlias, parsedToolName) + } + } } // Create a struct for this server @@ -312,7 +319,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi s.logger.Debug("%s No servers available for code mode execution", codemcp.CodeModeLogPrefix) } - // Step 4: Create Starlark thread with print function and timeout + // Step 3: Create Starlark thread with print function and timeout toolExecutionTimeout := s.getToolExecutionTimeout() timeoutCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) defer cancel() @@ -324,11 +331,26 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi }, } - // Set up cancellation check + // Set up cancellation check — watch the context and cancel the Starlark + // thread so that infinite loops and other long-running scripts are interrupted + // when the execution timeout fires. thread.SetLocal("context", timeoutCtx) + go func() { + <-timeoutCtx.Done() + thread.Cancel(timeoutCtx.Err().Error()) + }() + + // Step 4: Configure Starlark dialect options for a Python-like experience + starlarkOpts := &syntax.FileOptions{ + TopLevelControl: true, // allow if/for/while at top level (not just inside functions) + While: true, // enable while loops + Set: true, // enable set() builtin + GlobalReassign: true, // allow reassignment to top-level names + Recursion: true, // allow recursive functions + } // Step 5: Execute the code - globals, err := starlark.ExecFile(thread, "code.star", trimmedCode, predeclared) + globals, err := starlark.ExecFileOptions(starlarkOpts, thread, "code.star", trimmedCode, predeclared) if err != nil { errorMessage := err.Error() @@ -372,7 +394,7 @@ func (s *StarlarkCodeMode) executeCode(ctx context.Context, code string) Executi } // callMCPTool calls an MCP tool and returns the result. 
-func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { +func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { // Get available tools per client availableToolsPerClient := s.clientManager.GetToolPerClient(ctx) @@ -400,29 +422,25 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName // Strip the client name prefix from tool name before calling MCP server originalToolName := stripClientPrefix(toolName, clientName) - // Get BifrostContext for plugin pipeline - var bifrostCtx *schemas.BifrostContext - var ok bool - if bifrostCtx, ok = ctx.(*schemas.BifrostContext); !ok { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + originalRequestID, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string) + if !ok { + originalRequestID = "" } - originalRequestID, _ := bifrostCtx.Value(schemas.BifrostContextKeyRequestID).(string) - // Generate new request ID for this nested tool call var newRequestID string if s.fetchNewRequestIDFunc != nil { - newRequestID = s.fetchNewRequestIDFunc(bifrostCtx) + newRequestID = s.fetchNewRequestIDFunc(ctx) } else { newRequestID = fmt.Sprintf("exec_%d_%s", time.Now().UnixNano(), toolName) } // Create new child context - deadline, hasDeadline := bifrostCtx.Deadline() + deadline, hasDeadline := ctx.Deadline() if !hasDeadline { deadline = schemas.NoDeadline } - nestedCtx := schemas.NewBifrostContext(bifrostCtx, deadline) + nestedCtx := schemas.NewBifrostContext(ctx, deadline) nestedCtx.SetValue(schemas.BifrostContextKeyRequestID, newRequestID) if originalRequestID != "" { nestedCtx.SetValue(schemas.BifrostContextKeyParentMCPRequestID, originalRequestID) @@ -451,13 +469,17 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, 
clientName, toolName // Check if plugin pipeline is available if s.pluginPipelineProvider == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline provider is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline provider is nil") } // Get plugin pipeline and run hooks pipeline := s.pluginPipelineProvider() if pipeline == nil { - return s.callMCPToolDirect(ctx, client, originalToolName, clientName, toolName, args, appendLog) + // Should never happen, but just in case + s.logger.Warn("%s Plugin pipeline is nil", codemcp.CodeModeLogPrefix) + return nil, fmt.Errorf("plugin pipeline is nil") } defer s.releasePluginPipeline(pipeline) @@ -515,14 +537,7 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName Name: toolNameToCall, Arguments: args, }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers + Header: utils.GetHeadersForToolExecution(nestedCtx, client), } toolExecutionTimeout := s.getToolExecutionTimeout() @@ -604,57 +619,3 @@ func (s *StarlarkCodeMode) callMCPTool(ctx context.Context, clientName, toolName return nil, fmt.Errorf("plugin post-hooks returned invalid response") } - -// callMCPToolDirect executes an MCP tool call directly without plugin hooks. 
-func (s *StarlarkCodeMode) callMCPToolDirect(ctx context.Context, client *schemas.MCPClientState, originalToolName, clientName, toolName string, args map[string]interface{}, appendLog func(string)) (interface{}, error) { - callRequest := mcp.CallToolRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodToolsCall), - }, - Params: mcp.CallToolParams{ - Name: originalToolName, - Arguments: args, - }, - } - - if client.ExecutionConfig.Headers != nil { - headers := make(http.Header) - for key, value := range client.ExecutionConfig.Headers { - headers.Add(key, value.GetValue()) - } - callRequest.Header = headers - } - - toolExecutionTimeout := s.getToolExecutionTimeout() - toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) - defer cancel() - - logToolName := stripClientPrefix(toolName, clientName) - logToolName = strings.ReplaceAll(logToolName, "-", "_") - - toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest) - if callErr != nil { - s.logger.Debug("%s Tool call failed: %s.%s - %v", codemcp.CodeModeLogPrefix, clientName, logToolName, callErr) - appendLog(fmt.Sprintf("[TOOL] %s.%s error: %v", clientName, logToolName, callErr)) - return nil, fmt.Errorf("tool call failed for %s.%s: %v", clientName, logToolName, callErr) - } - - rawResult := extractTextFromMCPResponse(toolResponse, toolName) - - if after, ok := strings.CutPrefix(rawResult, "Error: "); ok { - errorMsg := after - s.logger.Debug("%s Tool returned error result: %s.%s - %s", codemcp.CodeModeLogPrefix, clientName, logToolName, errorMsg) - appendLog(fmt.Sprintf("[TOOL] %s.%s error result: %s", clientName, logToolName, errorMsg)) - return nil, fmt.Errorf("%s", errorMsg) - } - - var finalResult interface{} - if err := sonic.Unmarshal([]byte(rawResult), &finalResult); err != nil { - finalResult = rawResult - } - - resultStr := formatResultForLog(finalResult) - appendLog(fmt.Sprintf("[TOOL] %s.%s raw response: %s", clientName, logToolName, resultStr)) - - return finalResult, nil -} 
diff --git a/core/mcp/codemode/starlark/getdocs.go b/core/mcp/codemode/starlark/getdocs.go index 61e4b1dc86..ea622bf7cf 100644 --- a/core/mcp/codemode/starlark/getdocs.go +++ b/core/mcp/codemode/starlark/getdocs.go @@ -71,8 +71,6 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var matchedTool *schemas.ChatTool serverNameLower := strings.ToLower(serverName) - toolNameLower := strings.ToLower(toolName) - for clientName, tools := range availableToolsPerClient { client := s.clientManager.GetClientByName(clientName) if client == nil { @@ -90,10 +88,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem // Find the specific tool for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { matchedTool = &tools[i] break } @@ -125,9 +120,7 @@ func (s *StarlarkCodeMode) handleGetToolDocs(ctx context.Context, toolCall schem var availableTools []string for _, tool := range tools { if tool.Function != nil { - unprefixedToolName := stripClientPrefix(tool.Function.Name, matchedClientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(matchedClientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. 
Available tools are:\n", toolName, matchedClientName) @@ -150,7 +143,7 @@ func generateTypeDefinitions(clientName string, tools []schemas.ChatTool, isTool // Write comprehensive header sb.WriteString("# ============================================================================\n") if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, tools[0].Function.Name)) + sb.WriteString(fmt.Sprintf("# Documentation for %s.%s tool\n", clientName, getCanonicalToolName(clientName, tools[0].Function.Name))) } else { sb.WriteString(fmt.Sprintf("# Documentation for %s MCP server\n", clientName)) } @@ -187,9 +180,7 @@ } originalToolName := tool.Function.Name - unprefixedToolName := stripClientPrefix(originalToolName, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, originalToolName) description := "" if tool.Function.Description != nil { description = *tool.Function.Description diff --git a/core/mcp/codemode/starlark/listfiles.go b/core/mcp/codemode/starlark/listfiles.go index caff015194..4d6aa73add 100644 --- a/core/mcp/codemode/starlark/listfiles.go +++ b/core/mcp/codemode/starlark/listfiles.go @@ -21,7 +21,8 @@ func (s *StarlarkCodeMode) createListToolFilesTool() schemas.ChatTool { if bindingLevel == schemas.CodeModeBindingLevelServer { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers. " + "Each server has a corresponding file (e.g., servers/<server_name>.pyi) that contains compact Python signatures for all tools in that server. " + - "Use readToolFile to read a specific server file and see all available tools with their signatures. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. 
" + + "Use readToolFile before executeToolCode to read a specific server file and confirm exact callable tool names and parameters. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. " + @@ -30,7 +31,9 @@ } else { description = "Returns a tree structure listing all virtual .pyi stub files available for connected MCP servers, organized by individual tool. " + "Each tool has a corresponding file (e.g., servers/<server_name>/<tool_name>.pyi) that contains compact Python signatures for that specific tool. " + - "Use readToolFile to read a specific tool file and see its signature. " + + "The <tool_name> shown in each filename is the exact canonical identifier exposed in executeToolCode. " + + "Safe workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + + "Use readToolFile before executeToolCode to confirm the exact signature and parameters for the tool you want to call. " + "Use getToolDocs if you need detailed documentation for a specific tool. " + "In code, access tools via: server_name.tool_name(param=value). " + "The server names used in code correspond to the human-readable names shown in this listing. 
" + @@ -88,12 +91,7 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch // Tool-level: one file per tool for _, tool := range tools { if tool.Function != nil && tool.Function.Name != "" { - // Strip the client prefix from tool name (format: "client-toolname" -> "toolname") - // But replace - with _ for valid Python identifiers - toolName := stripClientPrefix(tool.Function.Name, clientName) - // Replace any remaining hyphens with underscores for Python compatibility - toolName = strings.ReplaceAll(toolName, "-", "_") - // Validate normalized tool name to prevent path traversal + toolName := getCanonicalToolName(clientName, tool.Function.Name) if err := validateNormalizedToolName(toolName); err != nil { s.logger.Warn("%s Skipping tool '%s' from client '%s': %v", codemcp.CodeModeLogPrefix, tool.Function.Name, clientName, err) continue @@ -112,10 +110,32 @@ func (s *StarlarkCodeMode) handleListToolFiles(ctx context.Context, toolCall sch } // Build tree structure from file list - responseText := buildVFSTree(files) + responseText := buildListToolFilesResponse(files, bindingLevel) return createToolResponseMessage(toolCall, responseText), nil } +func buildListToolFilesResponse(files []string, bindingLevel schemas.CodeModeBindingLevel) string { + tree := buildVFSTree(files) + if tree == "" { + return "" + } + + header := []string{ + "# Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode", + } + + if bindingLevel == schemas.CodeModeBindingLevelServer { + header = append(header, "# Read the server .pyi file before executeToolCode to confirm exact tool names and parameters.") + } else { + header = append(header, + "# Filenames below use the exact canonical tool identifiers available in executeToolCode.", + "# Still call readToolFile before executeToolCode to confirm parameters and return shape.", + ) + } + + return strings.Join(append(header, "", tree), "\n") +} + // VFS tree node structure for building 
hierarchical file structure type treeNode struct { isDirectory bool diff --git a/core/mcp/codemode/starlark/readfile.go b/core/mcp/codemode/starlark/readfile.go index 5940e0c9fc..41063ad065 100644 --- a/core/mcp/codemode/starlark/readfile.go +++ b/core/mcp/codemode/starlark/readfile.go @@ -21,21 +21,23 @@ func (s *StarlarkCodeMode) createReadToolFileTool() schemas.ChatTool { var fileNameDescription, toolDescription string if bindingLevel == schemas.CodeModeBindingLevelServer { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/<server_name>.pyi (e.g., 'calculator.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers/<server_name>.pyi (e.g., 'servers/calculator.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific MCP server, returning compact Python function signatures " + "for all tools available on that server. The fileName should be in format servers/<server_name>.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + + "This is the authoritative source for the exact callable tool names and parameters to use in executeToolCode. " + "Each tool can be accessed in code via: serverName.tool_name(param=value). " + "If the compact signature is not enough to understand a tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. " + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + "do NOT call this tool again with startLine/endLine - you already have the complete file."
} else { - fileNameDescription = "The virtual filename from listToolFiles in format: servers/<server_name>/<tool_name>.pyi (e.g., 'calculator/add.pyi')" + fileNameDescription = "The virtual filename from listToolFiles in format: servers/<server_name>/<tool_name>.pyi (e.g., 'servers/calculator/add.pyi')" toolDescription = "Reads a virtual .pyi stub file for a specific tool, returning its compact Python function signature. " + "The fileName should be in format servers/<server_name>/<tool_name>.pyi as listed by listToolFiles. " + "The function performs case-insensitive matching and removes the .pyi extension. " + - "The tool can be accessed in code via: serverName.tool_name(param=value). " + + "This is the authoritative source for the exact callable tool name and arguments to use in executeToolCode. " + + "The tool can be accessed in code via: serverName.tool_name(param=value) using the def name shown in the file. " + "If the compact signature is not enough to understand the tool, use getToolDocs for detailed documentation. " + "Workflow: listToolFiles -> readToolFile -> (optional) getToolDocs -> executeToolCode. 
" + "IMPORTANT: If the response header shows 'Total lines: X (this is the complete file)', " + @@ -126,13 +128,9 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche if isToolLevel { // Tool-level: filter to specific tool var foundTool *schemas.ChatTool - toolNameLower := strings.ToLower(toolName) for i, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for comparison - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - if strings.ToLower(unprefixedToolName) == toolNameLower { + if matchesToolReference(toolName, clientName, tool.Function.Name) { foundTool = &tools[i] break } @@ -143,15 +141,12 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche availableTools := make([]string, 0) for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableTools = append(availableTools, unprefixedToolName) + availableTools = append(availableTools, getCanonicalToolName(clientName, tool.Function.Name)) } } errorMsg := fmt.Sprintf("Tool '%s' not found in server '%s'. 
Available tools in this server are:\n", toolName, clientName) for _, t := range availableTools { - errorMsg += fmt.Sprintf(" - %s/%s.pyi\n", clientName, t) + errorMsg += fmt.Sprintf(" - servers/%s/%s.pyi\n", clientName, t) } return createToolResponseMessage(toolCall, errorMsg), nil } @@ -171,17 +166,14 @@ func (s *StarlarkCodeMode) handleReadToolFile(ctx context.Context, toolCall sche for name := range availableToolsPerClient { if bindingLevel == schemas.CodeModeBindingLevelServer { - availableFiles = append(availableFiles, fmt.Sprintf("%s.pyi", name)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s.pyi", name)) } else { client := s.clientManager.GetClientByName(name) if client != nil && client.ExecutionConfig.IsCodeModeClient { if tools, ok := availableToolsPerClient[name]; ok { for _, tool := range tools { if tool.Function != nil { - // Strip client prefix and replace - with _ for display - unprefixedToolName := stripClientPrefix(tool.Function.Name, name) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - availableFiles = append(availableFiles, fmt.Sprintf("%s/%s.pyi", name, unprefixedToolName)) + availableFiles = append(availableFiles, fmt.Sprintf("servers/%s/%s.pyi", name, getCanonicalToolName(name, tool.Function.Name))) } } } @@ -295,12 +287,14 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo // Minimal header if isToolLevel && len(tools) == 1 && tools[0].Function != nil { - toolName := parseToolName(stripClientPrefix(tools[0].Function.Name, clientName)) + toolName := getCanonicalToolName(clientName, tools[0].Function.Name) sb.WriteString(fmt.Sprintf("# %s.%s tool\n", clientName, toolName)) } else { sb.WriteString(fmt.Sprintf("# %s server tools\n", clientName)) } sb.WriteString(fmt.Sprintf("# Usage: %s.tool_name(param=value)\n", clientName)) + sb.WriteString("# The def names below are the exact callable names to use in executeToolCode.\n") + sb.WriteString("# Read this file before 
executeToolCode to confirm parameters and return shape.\n") sb.WriteString(fmt.Sprintf("# For detailed docs: use getToolDocs(server=\"%s\", tool=\"tool_name\")\n", clientName)) sb.WriteString("# Note: Descriptions may be truncated. Use getToolDocs for full details.\n\n") @@ -309,10 +303,7 @@ func generateCompactSignatures(clientName string, tools []schemas.ChatTool, isTo continue } - // Strip client prefix and replace - with _ for code mode compatibility - unprefixedToolName := stripClientPrefix(tool.Function.Name, clientName) - unprefixedToolName = strings.ReplaceAll(unprefixedToolName, "-", "_") - toolName := parseToolName(unprefixedToolName) + toolName := getCanonicalToolName(clientName, tool.Function.Name) // Format inline parameters in Python style params := formatPythonParams(tool.Function.Parameters) diff --git a/core/mcp/codemode/starlark/starlark.go b/core/mcp/codemode/starlark/starlark.go index 0da1d2ccd9..348655b983 100644 --- a/core/mcp/codemode/starlark/starlark.go +++ b/core/mcp/codemode/starlark/starlark.go @@ -6,7 +6,6 @@ package starlark import ( - "context" "fmt" "sync" "sync/atomic" @@ -111,7 +110,7 @@ func (s *StarlarkCodeMode) GetTools() []schemas.ChatTool { // Returns: // - *schemas.ChatMessage: The tool response message // - error: Any error that occurred during execution -func (s *StarlarkCodeMode) ExecuteTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { +func (s *StarlarkCodeMode) ExecuteTool(ctx *schemas.BifrostContext, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { if toolCall.Function.Name == nil { return nil, fmt.Errorf("tool call missing function name") } diff --git a/core/mcp/codemode/starlark/starlark_test.go b/core/mcp/codemode/starlark/starlark_test.go index dba557f88a..a48e77a887 100644 --- a/core/mcp/codemode/starlark/starlark_test.go +++ b/core/mcp/codemode/starlark/starlark_test.go @@ -3,13 +3,42 @@ package starlark import ( + "context" + 
"strings" "testing" + "time" "github.com/bytedance/sonic" + codemcp "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/schemas" "go.starlark.net/starlark" + "go.starlark.net/syntax" ) +type testClientManager struct { + clients map[string]*schemas.MCPClientState + tools map[string][]schemas.ChatTool +} + +func (m *testClientManager) GetClientForTool(toolName string) *schemas.MCPClientState { + for clientName, tools := range m.tools { + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name == toolName { + return m.clients[clientName] + } + } + } + return nil +} + +func (m *testClientManager) GetClientByName(clientName string) *schemas.MCPClientState { + return m.clients[clientName] +} + +func (m *testClientManager) GetToolPerClient(ctx context.Context) map[string][]schemas.ChatTool { + return m.tools +} + func TestStarlarkToGo(t *testing.T) { t.Run("Convert None", func(t *testing.T) { result := starlarkToGo(starlark.None) @@ -151,6 +180,83 @@ func TestGoToStarlark(t *testing.T) { }) } +func TestGetCanonicalToolName(t *testing.T) { + if got := getCanonicalToolName("github", "github-SEARCH_REPOS"); got != "search_repos" { + t.Fatalf("expected canonical tool name search_repos, got %q", got) + } + + if got := getCanonicalToolName("math", "math-123Add!"); got != "_123add" { + t.Fatalf("expected canonical tool name _123add, got %q", got) + } +} + +func TestMatchesToolReferenceSupportsCanonicalAndLegacyNames(t *testing.T) { + clientName := "github" + originalToolName := "github-SEARCH_REPOS" + + testCases := []string{ + "search_repos", + "SEARCH_REPOS", + } + + for _, toolRef := range testCases { + if !matchesToolReference(toolRef, clientName, originalToolName) { + t.Fatalf("expected %q to match %q", toolRef, originalToolName) + } + } +} + +func TestHandleListToolFilesUsesCanonicalToolIdentifiers(t *testing.T) { + mode := NewStarlarkCodeMode(&codemcp.CodeModeConfig{ + BindingLevel: schemas.CodeModeBindingLevelTool, + 
ToolExecutionTimeout: time.Second, + }, nil) + + clientName := "github" + mode.clientManager = &testClientManager{ + clients: map[string]*schemas.MCPClientState{ + clientName: { + Name: clientName, + ExecutionConfig: &schemas.MCPClientConfig{ + Name: clientName, + IsCodeModeClient: true, + }, + }, + }, + tools: map[string][]schemas.ChatTool{ + clientName: { + { + Function: &schemas.ChatToolFunction{ + Name: "github-SEARCH_REPOS", + }, + }, + }, + }, + } + + msg, err := mode.handleListToolFiles(context.Background(), schemas.ChatAssistantMessageToolCall{ + ID: schemas.Ptr("tool-call-1"), + }) + if err != nil { + t.Fatalf("handleListToolFiles returned error: %v", err) + } + + if msg == nil || msg.Content == nil || msg.Content.ContentStr == nil { + t.Fatal("expected tool response content") + } + + content := *msg.Content.ContentStr + if !strings.Contains(content, "search_repos.pyi") { + t.Fatalf("expected canonical tool file path in response, got:\n%s", content) + } + if strings.Contains(content, "SEARCH_REPOS.pyi") { + t.Fatalf("did not expect raw uppercase tool file path in response, got:\n%s", content) + } + if !strings.Contains(content, "readToolFile before executeToolCode") { + t.Fatalf("expected workflow guidance in response, got:\n%s", content) + } +} + func TestGeneratePythonErrorHints(t *testing.T) { serverKeys := []string{"calculator", "weather"} @@ -161,13 +267,13 @@ func TestGeneratePythonErrorHints(t *testing.T) { } found := false for _, hint := range hints { - if containsAny(hint, "not defined", "undefined") { + if strings.Contains(hint, "Variable 'foo' is not defined.") { found = true break } } if !found { - t.Error("Expected hint about undefined variable") + t.Errorf("Expected exact undefined variable hint for foo, got: %v", hints) } }) @@ -489,3 +595,405 @@ func TestFormatResultForLog(t *testing.T) { } }) } + +// starlarkOpts returns the FileOptions used by the code mode executor. 
+// Kept in sync with executecode.go to test the same dialect configuration. +func starlarkOpts() *syntax.FileOptions { + return &syntax.FileOptions{ + TopLevelControl: true, + While: true, + Set: true, + GlobalReassign: true, + Recursion: true, + } +} + +// execStarlark is a test helper that executes Starlark code with our dialect options +// and returns the globals and any error. +func execStarlark(code string) (starlark.StringDict, error) { + thread := &starlark.Thread{Name: "test"} + return starlark.ExecFileOptions(starlarkOpts(), thread, "test.star", code, nil) +} + +func TestStarlarkDialectOptions(t *testing.T) { + t.Run("Top-level for loop", func(t *testing.T) { + code := ` +items = [] +for i in range(3): + items.append(i) +result = items +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level for loop should work, got error: %v", err) + } + resultVal := globals["result"] + list, ok := resultVal.(*starlark.List) + if !ok { + t.Fatalf("Expected list, got %T", resultVal) + } + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("Top-level if statement", func(t *testing.T) { + code := ` +x = 10 +if x > 5: + result = "big" +else: + result = "small" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level if should work, got error: %v", err) + } + if globals["result"] != starlark.String("big") { + t.Errorf("Expected 'big', got %v", globals["result"]) + } + }) + + t.Run("Top-level while loop", func(t *testing.T) { + code := ` +count = 0 +while count < 5: + count += 1 +result = count +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Top-level while loop should work, got error: %v", err) + } + resultVal := globals["result"] + if resultVal.String() != "5" { + t.Errorf("Expected 5, got %v", resultVal) + } + }) + + t.Run("While loop inside function", func(t *testing.T) { + code := ` +def countdown(n): + items = [] + while n > 0: + items.append(n) + n -= 1 + return 
items +result = countdown(3) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("While in function should work, got error: %v", err) + } + list := globals["result"].(*starlark.List) + if list.Len() != 3 { + t.Errorf("Expected 3 items, got %d", list.Len()) + } + }) + + t.Run("set() builtin", func(t *testing.T) { + code := ` +s = set([1, 2, 3, 2, 1]) +result = len(s) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("set() should work, got error: %v", err) + } + if globals["result"].String() != "3" { + t.Errorf("Expected 3 unique items, got %v", globals["result"]) + } + }) + + t.Run("Global variable reassignment", func(t *testing.T) { + code := ` +x = 1 +x = x + 1 +x = x * 3 +result = x +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Global reassignment should work, got error: %v", err) + } + if globals["result"].String() != "6" { + t.Errorf("Expected 6, got %v", globals["result"]) + } + }) + + t.Run("Recursive function", func(t *testing.T) { + code := ` +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) +result = factorial(5) +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Recursion should work, got error: %v", err) + } + if globals["result"].String() != "120" { + t.Errorf("Expected 120, got %v", globals["result"]) + } + }) +} + +func TestStarlarkStringEscapePreservation(t *testing.T) { + t.Run("Backslash-n in string literal preserved", func(t *testing.T) { + // Simulate what happens after JSON deserialization: + // Model writes: {"code": "msg = \"hello\\nworld\""} + // sonic.Unmarshal produces: msg = "hello\nworld" (where \n is two chars: \ + n) + // Starlark should interpret \n as newline escape inside the string + code := "msg = \"hello\\nworld\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with \\n escape should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr 
!= "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Multiple escape sequences in strings", func(t *testing.T) { + code := "msg = \"col1\\tcol2\\nrow1\\trow2\"\nresult = msg" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("String with multiple escapes should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "col1\tcol2\nrow1\trow2" { + t.Errorf("Expected tab/newline escapes, got %q", resultStr) + } + }) + + t.Run("Newline join pattern", func(t *testing.T) { + // This is the exact pattern that failed 7 times in benchmarks + code := ` +def main(): + lines = ["line1", "line2", "line3"] + content = "\n".join(lines) + return content +result = main() +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Newline join pattern should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected joined lines, got %q", resultStr) + } + }) + + t.Run("chr() for newline", func(t *testing.T) { + code := ` +nl = chr(10) +result = "hello" + nl + "world" +` + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("chr(10) should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "hello\nworld" { + t.Errorf("Expected 'helloworld', got %q", resultStr) + } + }) + + t.Run("Triple-quoted strings", func(t *testing.T) { + code := "result = \"\"\"line1\nline2\nline3\"\"\"" + + globals, err := execStarlark(code) + if err != nil { + t.Fatalf("Triple-quoted string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "line1\nline2\nline3" { + t.Errorf("Expected multiline string, got %q", resultStr) + } + }) + + t.Run("Raw string preserves backslash", func(t *testing.T) { + code := "result = r\"hello\\nworld\"" + + globals, err := 
execStarlark(code) + if err != nil { + t.Fatalf("Raw string should work, got error: %v", err) + } + resultStr := string(globals["result"].(starlark.String)) + // Raw string: \n stays as two characters \ and n + if resultStr != "hello\\nworld" { + t.Errorf("Expected literal backslash-n, got %q", resultStr) + } + }) + + t.Run("JSON deserialization then Starlark execution", func(t *testing.T) { + // End-to-end: simulate the exact flow from model JSON → sonic.Unmarshal → Starlark + jsonArgs := `{"code": "lines = [\"a\", \"b\", \"c\"]\nresult = \"\\n\".join(lines)"}` + + var arguments map[string]interface{} + err := sonic.Unmarshal([]byte(jsonArgs), &arguments) + if err != nil { + t.Fatalf("JSON unmarshal failed: %v", err) + } + + code := arguments["code"].(string) + + globals, starlarkErr := execStarlark(code) + if starlarkErr != nil { + t.Fatalf("Starlark execution failed: %v", starlarkErr) + } + resultStr := string(globals["result"].(starlark.String)) + if resultStr != "a\nb\nc" { + t.Errorf("Expected 'abc', got %q", resultStr) + } + }) +} + +func TestStarlarkUnsupportedFeatures(t *testing.T) { + t.Run("try/except rejected", func(t *testing.T) { + code := ` +def main(): + try: + x = 1 + except: + x = 0 +result = main() +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("try/except should be rejected by Starlark") + } + if !strings.Contains(err.Error(), "got try") { + t.Errorf("Expected 'got try' in error, got: %v", err) + } + }) + + t.Run("raise rejected", func(t *testing.T) { + code := `raise ValueError("test")` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("raise should be rejected by Starlark") + } + }) + + t.Run("class rejected", func(t *testing.T) { + code := ` +class Foo: + pass +` + _, err := execStarlark(code) + if err == nil { + t.Fatal("class should be rejected by Starlark") + } + }) + + t.Run("import rejected", func(t *testing.T) { + code := `import json` + + _, err := execStarlark(code) + if err == nil { + t.Fatal("import should 
be rejected by Starlark") + } + }) +} + +func TestGeneratePythonErrorHintsNewCases(t *testing.T) { + serverKeys := []string{"Github", "SqLite"} + + t.Run("try/except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:9: got try, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for try/except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about try/except not being supported, got: %v", hints) + } + }) + + t.Run("except hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:5:9: got except, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for except error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("finally hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:7:9: got finally, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for finally error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("raise hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:2:1: got raise, want primary expression", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for raise error") + } + found := false + for _, hint := range hints { + if containsAny(hint, "try/except", "exception handling") { + found = true + break + } + } + if !found { + t.Errorf("Expected hint about exception handling, got: %v", hints) + } + }) + + t.Run("Undefined variable 
includes scope hint", func(t *testing.T) { + hints := generatePythonErrorHints("code.star:3:17: undefined: commits_n8n", serverKeys) + if len(hints) == 0 { + t.Fatal("Expected hints for undefined variable") + } + foundVar := false + foundScope := false + for _, hint := range hints { + if strings.Contains(hint, "Variable 'commits_n8n' is not defined.") { + foundVar = true + } + if containsAny(hint, "fresh scope", "persist") { + foundScope = true + } + } + if !foundVar { + t.Errorf("Expected exact undefined variable hint for commits_n8n, got: %v", hints) + } + if !foundScope { + t.Errorf("Expected scope persistence hint, got: %v", hints) + } + }) +} diff --git a/core/mcp/codemode/starlark/utils.go b/core/mcp/codemode/starlark/utils.go index 5b7c9ab920..aea6e732d3 100644 --- a/core/mcp/codemode/starlark/utils.go +++ b/core/mcp/codemode/starlark/utils.go @@ -191,11 +191,25 @@ func formatResultForLog(result interface{}) string { func generatePythonErrorHints(errorMessage string, serverKeys []string) []string { hints := []string{} - if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") { - re := regexp.MustCompile(`(\w+).*(?:undefined|not defined)`) - if match := re.FindStringSubmatch(errorMessage); len(match) > 1 { - undefinedVar := match[1] + if strings.Contains(errorMessage, "got try") || strings.Contains(errorMessage, "got except") || + strings.Contains(errorMessage, "got finally") || strings.Contains(errorMessage, "got raise") { + hints = append(hints, "Starlark does NOT support try/except/finally/raise — there is no exception handling.") + hints = append(hints, "Instead, check return values for errors:") + hints = append(hints, " result = server.tool(param=\"value\")") + hints = append(hints, " if result == None or (type(result) == \"dict\" and \"error\" in result):") + hints = append(hints, " print(\"Error:\", result)") + } else if strings.Contains(errorMessage, "undefined") || strings.Contains(errorMessage, "not defined") 
{ + var undefinedVar string + if match := regexp.MustCompile(`name ['"]([^'"]+)['"] is not defined`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`undefined:\s*([A-Za-z_][A-Za-z0-9_]*)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } else if match := regexp.MustCompile(`([A-Za-z_][A-Za-z0-9_]*)[^A-Za-z0-9_]+(?:undefined|not defined)`).FindStringSubmatch(errorMessage); len(match) > 1 { + undefinedVar = match[1] + } + if undefinedVar != "" { hints = append(hints, fmt.Sprintf("Variable '%s' is not defined.", undefinedVar)) + hints = append(hints, "Note: Each executeToolCode call runs in a fresh scope — no variables persist between calls.") if len(serverKeys) > 0 { hints = append(hints, fmt.Sprintf("Available server keys: %s", strings.Join(serverKeys, ", "))) hints = append(hints, "Access tools using: server_name.tool_name(param=\"value\")") @@ -298,7 +312,7 @@ func createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, re } } -// parseToolName parses the tool name to be JavaScript-compatible. +// parseToolName normalizes a raw tool name into a Starlark-compatible identifier. func parseToolName(toolName string) string { if toolName == "" { return "" @@ -349,6 +363,61 @@ func parseToolName(toolName string) string { return parsed } +// getCanonicalToolName returns the exact callable tool identifier exposed in Starlark. +func getCanonicalToolName(clientName, originalToolName string) string { + return parseToolName(stripClientPrefix(originalToolName, clientName)) +} + +// getCompatibilityToolAlias returns the case-preserving alias derived from the raw tool name. +// This is used as a compatibility alias when the raw name is still a valid Starlark identifier. 
+func getCompatibilityToolAlias(clientName, originalToolName string) string {
+	return strings.ReplaceAll(stripClientPrefix(originalToolName, clientName), "-", "_")
+}
+
+// matchesToolReference reports whether the requested tool name matches any supported identifier form.
+// We accept the canonical callable name plus legacy display forms for backward compatibility.
+func matchesToolReference(requestedToolName, clientName, originalToolName string) bool {
+	requested := strings.ToLower(requestedToolName)
+	if requested == "" {
+		return false
+	}
+
+	candidates := []string{
+		getCanonicalToolName(clientName, originalToolName),
+		getCompatibilityToolAlias(clientName, originalToolName),
+		stripClientPrefix(originalToolName, clientName),
+	}
+
+	for _, candidate := range candidates {
+		if candidate != "" && requested == strings.ToLower(candidate) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// isValidStarlarkIdentifier reports whether name can be used directly in Starlark code.
+func isValidStarlarkIdentifier(name string) bool {
+	if name == "" {
+		return false
+	}
+
+	runes := []rune(name)
+	first := runes[0]
+	// Starlark identifiers may contain only letters, digits, and underscores,
+	if !unicode.IsLetter(first) && first != '_' {
+		return false
+	}
+
+	// and may not start with a digit ('$' is not a valid identifier character).
+	for _, r := range runes[1:] {
+		if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' {
+			return false
+		}
+	}
+
+	return true
+}
+
 // validateNormalizedToolName validates a normalized tool name to prevent path traversal.
func validateNormalizedToolName(normalizedName string) error { if normalizedName == "" { diff --git a/core/mcp/healthmonitor.go b/core/mcp/healthmonitor.go index aa6595fe7a..85769afdcb 100644 --- a/core/mcp/healthmonitor.go +++ b/core/mcp/healthmonitor.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/mcp" "github.com/maximhq/bifrost/core/schemas" ) @@ -140,9 +141,14 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } chm.mu.Unlock() - // Get the client connection + // Get the client connection — capture Conn while holding the lock so we + // don't race with removeClientUnsafe zeroing it under the write lock. chm.manager.mu.RLock() clientState, exists := chm.manager.clientMap[chm.clientID] + var conn *client.Client + if exists && clientState != nil { + conn = clientState.Conn + } chm.manager.mu.RUnlock() if !exists { @@ -151,7 +157,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { } var err error - if clientState.Conn == nil { + if conn == nil { // No active connection — treat as a health check failure err = fmt.Errorf("no active connection") } else { @@ -160,7 +166,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { defer cancel() if chm.isPingAvailable { - err = clientState.Conn.Ping(ctx) + err = conn.Ping(ctx) } else { listRequest := mcp.ListToolsRequest{ PaginatedRequest: mcp.PaginatedRequest{ @@ -169,7 +175,7 @@ func (chm *ClientHealthMonitor) performHealthCheck() { }, }, } - _, err = clientState.Conn.ListTools(ctx, listRequest) + _, err = conn.ListTools(ctx, listRequest) } } diff --git a/core/mcp/interface.go b/core/mcp/interface.go index 93617ce511..3069e2692a 100644 --- a/core/mcp/interface.go +++ b/core/mcp/interface.go @@ -14,15 +14,17 @@ import ( type MCPManagerInterface interface { // Tool Operations // AddToolsToRequest parses available MCP tools and adds them to the request - AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest 
+ AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest // GetAvailableTools returns all available MCP tools for the given context - GetAvailableTools(ctx context.Context) []schemas.ChatTool + GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool // ExecuteToolCall executes a single tool call and returns the result ExecuteToolCall(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) - // UpdateToolManagerConfig updates the configuration for the tool manager + // UpdateToolManagerConfig updates the configuration for the tool manager. + // DisableAutoToolInject in the config controls auto injection — pass the + // current value whenever only other fields change so it is never silently reset. UpdateToolManagerConfig(config *schemas.MCPToolManagerConfig) // Agent Mode Operations @@ -60,6 +62,14 @@ type MCPManagerInterface interface { // ReconnectClient reconnects an MCP client by ID ReconnectClient(id string) error + // VerifyPerUserOAuthConnection creates a temporary MCP connection using a + // test access token to verify connectivity and discover tools. The connection + // is closed after verification. + VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) + + // SetClientTools updates the tool map and name mapping for an existing client. 
+ SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) + // Tool Registration // RegisterTool registers a local tool with the MCP server RegisterTool(name, description string, toolFunction MCPToolFunction[any], toolSchema schemas.ChatTool) error diff --git a/core/mcp/mcp.go b/core/mcp/mcp.go index b86409ef1a..adcf3158ac 100644 --- a/core/mcp/mcp.go +++ b/core/mcp/mcp.go @@ -21,13 +21,6 @@ const ( BifrostMCPClientKey = "bifrostInternal" // Key for internal Bifrost client in clientMap MCPLogPrefix = "[Bifrost MCP]" // Consistent logging prefix MCPClientConnectionEstablishTimeout = 30 * time.Second // Timeout for MCP client connection establishment - - // Context keys for client filtering in requests - // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). - // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. - // Request context filtering takes priority over client config - context can override client exclusions. 
- MCPContextKeyIncludeClients schemas.BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering - MCPContextKeyIncludeTools schemas.BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) ) // ============================================================================ @@ -110,7 +103,7 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider } } - manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, logger) + manager.toolsManager = NewToolsManager(config.ToolManagerConfig, manager, config.FetchNewRequestIDFunc, pluginPipelineProvider, releasePluginPipeline, oauth2Provider, logger) // Set up CodeMode if provided - inject dependencies after manager is created if codeMode != nil { @@ -149,7 +142,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider manager.clientMap[clientConfig.ID].State = schemas.MCPConnectionStateDisconnected } manager.mu.Unlock() - monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, clientConfig.IsPingAvailable, manager.logger) + isPingAvailable := true + if clientConfig.IsPingAvailable != nil { + isPingAvailable = *clientConfig.IsPingAvailable + } + monitor := NewClientHealthMonitor(manager, clientConfig.ID, DefaultHealthCheckInterval, isPingAvailable, manager.logger) manager.healthMonitorManager.StartMonitoring(monitor) } }(clientConfig) @@ -160,6 +157,13 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider return manager } +// SetPluginPipeline updates the plugin pipeline provider and release function on the manager's +// ToolsManager and CodeMode. 
Call this after attaching an externally-created MCPManager to a Bifrost +// instance so that nested tool calls in code mode can run through Bifrost's plugin hooks. +func (manager *MCPManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + manager.toolsManager.SetPluginPipeline(provider, release) +} + // AddToolsToRequest parses available MCP tools from the context and adds them to the request. // It respects context-based filtering for clients and tools, and returns the modified request // with tools attached. @@ -170,11 +174,11 @@ func NewMCPManager(ctx context.Context, config schemas.MCPConfig, oauth2Provider // // Returns: // - *schemas.BifrostRequest: The request with tools added -func (m *MCPManager) AddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *MCPManager) AddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { return m.toolsManager.ParseAndAddToolsToRequest(ctx, req) } -func (m *MCPManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *MCPManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { return m.toolsManager.GetAvailableTools(ctx) } diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index 8ba68e65ab..0527b1d46c 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -5,13 +5,17 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" "sync/atomic" "time" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" + "github.com/maximhq/bifrost/core/mcp/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -31,11 +35,15 @@ type PluginPipeline interface { // ToolsManager manages MCP tool execution and agent mode. 
type ToolsManager struct { - toolExecutionTimeout atomic.Value - maxAgentDepth atomic.Int32 - clientManager ClientManager - logger schemas.Logger - agentModeExecutor *AgentModeExecutor + toolExecutionTimeout atomic.Value + maxAgentDepth atomic.Int32 + disableAutoToolInject atomic.Bool + clientManager ClientManager + logger schemas.Logger + agentModeExecutor *AgentModeExecutor + + // OAuth2Provider for per-user OAuth token management + oauth2Provider schemas.OAuth2Provider // CodeMode implementation for code execution (Starlark by default) codeMode CodeMode @@ -73,6 +81,7 @@ func NewToolsManager( fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string, pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { return NewToolsManagerWithCodeMode( @@ -82,6 +91,7 @@ func NewToolsManager( pluginPipelineProvider, releasePluginPipeline, nil, // Use default code mode (will be set later via SetCodeMode) + oauth2Provider, logger, ) } @@ -106,6 +116,7 @@ func NewToolsManagerWithCodeMode( pluginPipelineProvider func() PluginPipeline, releasePluginPipeline func(pipeline PluginPipeline), codeMode CodeMode, + oauth2Provider schemas.OAuth2Provider, logger schemas.Logger, ) *ToolsManager { if config == nil { @@ -142,11 +153,13 @@ func NewToolsManagerWithCodeMode( codeMode: codeMode, logger: logger, agentModeExecutor: agentModeExecutor, + oauth2Provider: oauth2Provider, } // Initialize atomic values manager.toolExecutionTimeout.Store(config.ToolExecutionTimeout) manager.maxAgentDepth.Store(int32(config.MaxAgentDepth)) + manager.disableAutoToolInject.Store(config.DisableAutoToolInject) manager.logger.Info("%s tool manager initialized with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) return manager @@ -174,8 +187,20 @@ func (m 
*ToolsManager) GetCodeModeDependencies() *CodeModeDependencies { } } +// SetPluginPipeline updates the plugin pipeline provider and release function +// on both the ToolsManager and its CodeMode implementation. +// This is used when an externally-created MCPManager is attached to a Bifrost instance +// via SetMCPManager, so the CodeMode can route nested tool calls through Bifrost's plugin hooks. +func (m *ToolsManager) SetPluginPipeline(provider func() PluginPipeline, release func(PluginPipeline)) { + m.pluginPipelineProvider = provider + m.releasePluginPipeline = release + if m.codeMode != nil { + m.codeMode.SetDependencies(m.GetCodeModeDependencies()) + } +} + // GetAvailableTools returns the available tools for the given context. -func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool { +func (m *ToolsManager) GetAvailableTools(ctx *schemas.BifrostContext) []schemas.ChatTool { availableToolsPerClient := m.clientManager.GetToolPerClient(ctx) // Flatten tools from all clients into a single slice, avoiding duplicates var availableTools []schemas.ChatTool @@ -191,14 +216,14 @@ func (m *ToolsManager) GetAvailableTools(ctx context.Context) []schemas.ChatTool } if client.ExecutionConfig.IsCodeModeClient { includeCodeModeTools = true - } else { - // Add tools from this client, checking for duplicates - for _, tool := range clientTools { - if tool.Function != nil && tool.Function.Name != "" { - if !seenToolNames[tool.Function.Name] { - availableTools = append(availableTools, tool) - seenToolNames[tool.Function.Name] = true - } + } + // Add tools from this client, checking for duplicates + for _, tool := range clientTools { + if tool.Function != nil && tool.Function.Name != "" && !seenToolNames[tool.Function.Name] { + seenToolNames[tool.Function.Name] = true + schemas.AppendToContextList(ctx, schemas.BifrostContextKeyMCPAddedTools, tool.Function.Name) + if !client.ExecutionConfig.IsCodeModeClient { + availableTools = append(availableTools, tool) 
} } } @@ -288,12 +313,22 @@ func buildIntegrationDuplicateCheckMap(existingTools []schemas.ChatTool, integra // // Returns: // - *schemas.BifrostRequest: Bifrost request with MCP tools added -func (m *ToolsManager) ParseAndAddToolsToRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { +func (m *ToolsManager) ParseAndAddToolsToRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) *schemas.BifrostRequest { // MCP is only supported for chat and responses requests if req.ChatRequest == nil && req.ResponsesRequest == nil { return req } + // When auto tool injection is disabled, only inject tools if the request + // has explicit context filters set (e.g. via x-bf-mcp-include-tools header). + if m.disableAutoToolInject.Load() { + includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools) + includeClients := ctx.Value(schemas.MCPContextKeyIncludeClients) + if includeTools == nil && includeClients == nil { + return req + } + } + availableTools := m.GetAvailableTools(ctx) if len(availableTools) == 0 { @@ -541,9 +576,90 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall Name: originalMCPToolName, Arguments: arguments, }, + Header: utils.GetHeadersForToolExecution(ctx, client), } - if client.ExecutionConfig.Headers != nil { + // Handle per-user OAuth: inject user-specific Authorization header + if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth { + if m.oauth2Provider == nil { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured") + } + virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string) + userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string) + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + + // Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key + if mcpUserID, _ := 
ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" { + userID = mcpUserID + } + + // Try identity-based token lookup first (works even without session token) + accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) + if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { + // Had session but token lookup failed with a real error (not just "not found") — return error + return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) + } + if err != nil { + // No token found — user hasn't authenticated with this MCP server yet. + // In LLM gateway mode with no identity, we can't track who this user is, + // so an OAuth flow would produce an orphaned token. Return a clear error instead. + isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) + if !isMCPGateway && userID == "" && virtualKeyID == "" { + return nil, "", "", fmt.Errorf( + "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", + client.ExecutionConfig.Name, + ) + } + + // Initiate OAuth flow to get a proper authorize URL with session tracking. 
+ if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) + } + redirectURI := buildRedirectURIFromContext(ctx) + if redirectURI == "" { + return nil, "", "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context") + } + flowInitiation, sessionID, flowErr := m.oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI) + if flowErr != nil { + return nil, "", "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr) + } + return nil, "", "", &schemas.MCPUserOAuthRequiredError{ + MCPClientID: client.ExecutionConfig.ID, + MCPClientName: client.ExecutionConfig.Name, + AuthorizeURL: flowInitiation.AuthorizeURL, + SessionID: sessionID, + Message: fmt.Sprintf("Authentication required for %s. 
Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name), + } + } + + if client.Conn == nil { + // No persistent connection — create temporary connection with user's token + toolExecutionTimeout := m.toolExecutionTimeout.Load().(time.Duration) + toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout) + defer cancel() + + toolResponse, callErr := executeToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger) + if callErr != nil { + if toolCtx.Err() == context.DeadlineExceeded { + return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName) + } + m.logger.Error("%s Tool execution failed for %s via client %s: %v", MCPLogPrefix, toolName, client.ExecutionConfig.Name, callErr) + return nil, "", "", fmt.Errorf("MCP tool call failed: %v", callErr) + } + responseText := extractTextFromMCPResponse(toolResponse, toolName) + return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil + } + + // Persistent connection exists — use per-call headers + headers := make(http.Header) + if client.ExecutionConfig.Headers != nil { + for key, value := range client.ExecutionConfig.Headers { + headers.Add(key, value.GetValue()) + } + } + headers.Set("Authorization", "Bearer "+accessToken) + callRequest.Header = headers + } else if client.ExecutionConfig.Headers != nil { headers := make(http.Header) for key, value := range client.ExecutionConfig.Headers { headers.Add(key, value.GetValue()) @@ -660,17 +776,99 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) { m.maxAgentDepth.Store(int32(config.MaxAgentDepth)) } - // Update CodeMode configuration if present - if m.codeMode != nil && config.CodeModeBindingLevel != "" { + // Update CodeMode configuration — propagate whenever either field is set + if m.codeMode != nil && (config.CodeModeBindingLevel != "" || config.ToolExecutionTimeout > 
0) { m.codeMode.UpdateConfig(&CodeModeConfig{ BindingLevel: config.CodeModeBindingLevel, ToolExecutionTimeout: config.ToolExecutionTimeout, }) } + m.disableAutoToolInject.Store(config.DisableAutoToolInject) + m.logger.Info("%s tool manager configuration updated with tool execution timeout: %v, max agent depth: %d, and code mode binding level: %s", MCPLogPrefix, config.ToolExecutionTimeout, config.MaxAgentDepth, config.CodeModeBindingLevel) } +// executeToolWithUserToken creates a temporary MCP connection using the user's +// OAuth access token, calls the specified tool, and closes the connection. +// This is used for per_user_oauth clients which have no persistent connection — +// each tool call gets its own short-lived connection authenticated with the +// requesting user's token. +// +// Parameters: +// - ctx: context with timeout for the entire operation +// - config: MCP client configuration (connection URL, name) +// - toolName: original MCP tool name to call +// - arguments: tool call arguments +// - accessToken: user's OAuth access token +// - logger: logger instance +// +// Returns: +// - *mcp.CallToolResult: tool execution result +// - error: any error during connection or execution +func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) { + if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" { + return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution") + } + + // Create HTTP transport with the user's Bearer token, preserving configured headers + headers := make(map[string]string) + if config.Headers != nil { + for key, value := range config.Headers { + headers[key] = value.GetValue() + } + } + headers["Authorization"] = "Bearer " + accessToken + httpTransport, err := transport.NewStreamableHTTP(config.ConnectionString.GetValue(), 
transport.WithHTTPHeaders(headers)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP transport: %w", err) + } + + // Create temporary MCP client + tempClient := client.NewClient(httpTransport) + if err := tempClient.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start temporary MCP connection: %w", err) + } + defer tempClient.Close() + + // Initialize MCP handshake + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{}, + ClientInfo: mcp.Implementation{ + Name: fmt.Sprintf("Bifrost-%s-user", config.Name), + Version: "1.0.0", + }, + }, + } + if _, err := tempClient.Initialize(ctx, initRequest); err != nil { + return nil, fmt.Errorf("failed to initialize temporary MCP connection: %w", err) + } + + // Call the tool + callRequest := mcp.CallToolRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodToolsCall), + }, + Params: mcp.CallToolParams{ + Name: toolName, + Arguments: arguments, + }, + } + return tempClient.CallTool(ctx, callRequest) +} + +// buildRedirectURIFromContext extracts the OAuth redirect URI from context. +// The URI is set by the HTTP middleware from the request's host. +func buildRedirectURIFromContext(ctx *schemas.BifrostContext) string { + if uri, ok := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string); ok && uri != "" { + return uri + } + // Fallback — should not happen if middleware is configured correctly + return "" +} + // GetCodeModeBindingLevel returns the current code mode binding level. // This method is safe to call concurrently from multiple goroutines. func (m *ToolsManager) GetCodeModeBindingLevel() schemas.CodeModeBindingLevel { diff --git a/core/mcp/utils.go b/core/mcp/utils.go index d80ec17acc..479bc0bad8 100644 --- a/core/mcp/utils.go +++ b/core/mcp/utils.go @@ -65,7 +65,7 @@ func (m *MCPManager) GetToolPerClient(ctx context.Context) map[string][]schemas. 
 	var includeClients []string
 	// Extract client filtering from request context
-	if existingIncludeClients, ok := ctx.Value(MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil {
+	if existingIncludeClients, ok := ctx.Value(schemas.MCPContextKeyIncludeClients).([]string); ok && existingIncludeClients != nil {
 		includeClients = existingIncludeClients
 	}
@@ -381,12 +381,12 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b
 	// If ToolsToExecute is specified (not nil), apply filtering
 	if config.ToolsToExecute != nil {
 		// Handle empty array [] - means no tools are allowed
-		if len(config.ToolsToExecute) == 0 {
+		if config.ToolsToExecute.IsEmpty() {
 			return true // No tools allowed
 		}
 
 		// Handle wildcard "*" - if present, all tools are allowed
-		if slices.Contains(config.ToolsToExecute, "*") {
+		if config.ToolsToExecute.IsUnrestricted() {
 			return false // All tools allowed
 		}
@@ -396,7 +396,7 @@ func shouldSkipToolForConfig(toolName string, config *schemas.MCPClientConfig) b
 		unprefixedToolName := stripClientPrefix(toolName, config.Name)
 
 		// Check if specific tool is in the allowed list
-		return !slices.Contains(config.ToolsToExecute, unprefixedToolName) // Tool not in allowed list
+		return !config.ToolsToExecute.Contains(unprefixedToolName) // Tool not in allowed list
 	}
 
 	return true // Tool is skipped (nil is treated as [] - no tools)
@@ -413,12 +413,12 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
 	// If ToolsToAutoExecute is specified (not nil), apply filtering
 	if config.ToolsToAutoExecute != nil {
 		// Handle empty array [] - means no tools are auto-executed
-		if len(config.ToolsToAutoExecute) == 0 {
+		if config.ToolsToAutoExecute.IsEmpty() {
 			return false // No tools auto-executed
 		}
 
 		// Handle wildcard "*" - if present, all tools are auto-executed
-		if slices.Contains(config.ToolsToAutoExecute, "*") {
+		if config.ToolsToAutoExecute.IsUnrestricted() {
 			return true // All tools auto-executed
 		}
@@ -428,7 +428,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
 		unprefixedToolName := stripClientPrefix(toolName, config.Name)
 
 		// Check if specific tool is in the auto-execute list
-		return slices.Contains(config.ToolsToAutoExecute, unprefixedToolName)
+		return config.ToolsToAutoExecute.Contains(unprefixedToolName)
 	}
 
 	return false // Tool is not auto-executed (nil is treated as [] - no tools)
@@ -439,7 +439,7 @@ func canAutoExecuteTool(toolName string, config *schemas.MCPClientConfig) bool {
 // Context filtering can only NARROW the tools available, NOT expand beyond client configuration.
 // This is checked AFTER client-level filtering (shouldSkipToolForConfig).
 func shouldSkipToolForRequest(ctx context.Context, clientName, toolName string) bool {
-	includeTools := ctx.Value(MCPContextKeyIncludeTools)
+	includeTools := ctx.Value(schemas.MCPContextKeyIncludeTools)
 
 	if includeTools != nil {
 		// Try []string first (preferred type)
@@ -487,6 +487,28 @@ func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) sch
 		// object schemas to always have a properties field, even if empty
 		properties = schemas.NewOrderedMap()
 	}
+
+	// Preserve MCP tool annotations if any are set.
+	// Clone bool pointers so Bifrost's copy is independent of the upstream mcp.Tool lifetime.
+	var annotations *schemas.MCPToolAnnotations
+	a := mcpTool.Annotations
+	if a.Title != "" || a.ReadOnlyHint != nil || a.DestructiveHint != nil || a.IdempotentHint != nil || a.OpenWorldHint != nil {
+		cloneBool := func(b *bool) *bool {
+			if b == nil {
+				return nil
+			}
+			v := *b
+			return &v
+		}
+		annotations = &schemas.MCPToolAnnotations{
+			Title:           a.Title,
+			ReadOnlyHint:    cloneBool(a.ReadOnlyHint),
+			DestructiveHint: cloneBool(a.DestructiveHint),
+			IdempotentHint:  cloneBool(a.IdempotentHint),
+			OpenWorldHint:   cloneBool(a.OpenWorldHint),
+		}
+	}
+
 	return schemas.ChatTool{
 		Type: schemas.ChatToolTypeFunction,
 		Function: &schemas.ChatToolFunction{
@@ -498,6 +520,7 @@ func convertMCPToolToBifrostSchema(mcpTool *mcp.Tool, logger schemas.Logger) sch
 				Required: mcpTool.InputSchema.Required,
 			},
 		},
+		Annotations: annotations,
 	}
 }
@@ -754,6 +777,7 @@ func hasToolCallsForChatResponse(response *schemas.BifrostChatResponse) bool {
 		if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
 			return true
 		}
+
 		// Check if message has tool calls regardless of finish_reason.
 		// Some providers (e.g. Gemini) return finish_reason "stop" even when tool calls are present,
 		// so we cannot rely solely on finish_reason to detect tool calls.
diff --git a/core/mcp/utils/utils.go b/core/mcp/utils/utils.go
new file mode 100644
index 0000000000..500792a09f
--- /dev/null
+++ b/core/mcp/utils/utils.go
@@ -0,0 +1,49 @@
+package utils
+
+import (
+	"net/http"
+
+	"github.com/maximhq/bifrost/core/schemas"
+)
+
+// GetHeadersForToolExecution sets additional headers for tool execution.
+// It returns the headers for the tool execution.
+func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
+	if ctx == nil || client == nil || client.ExecutionConfig == nil {
+		return make(http.Header)
+	}
+	headers := make(http.Header)
+	if client.ExecutionConfig.Headers != nil {
+		for key, value := range client.ExecutionConfig.Headers {
+			headers.Add(key, value.GetValue())
+		}
+	}
+	// Give priority to extra headers in the context
+	if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyMCPExtraHeaders).(map[string][]string); ok {
+		filteredHeaders := make(http.Header)
+		for key, values := range extraHeaders {
+			if client.ExecutionConfig.AllowedExtraHeaders.IsAllowed(key) {
+				for i, value := range values {
+					if i == 0 {
+						filteredHeaders.Set(key, value)
+					} else {
+						filteredHeaders.Add(key, value)
+					}
+				}
+			}
+		}
+		// Add the filtered headers to the headers
+		if len(filteredHeaders) > 0 {
+			for k, values := range filteredHeaders {
+				for i, v := range values {
+					if i == 0 {
+						headers.Set(k, v)
+					} else {
+						headers.Add(k, v)
+					}
+				}
+			}
+		}
+	}
+	return headers
+}
diff --git a/core/mcp/utils_test.go b/core/mcp/utils_test.go
index e74d9d7da1..1fba67db38 100644
--- a/core/mcp/utils_test.go
+++ b/core/mcp/utils_test.go
@@ -1,9 +1,13 @@
 package mcp
 
 import (
+	"encoding/json"
 	"testing"
 
 	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/maximhq/bifrost/core/schemas"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 // TestConvertMCPToolToBifrostSchema_EmptyParameters tests that tools with no parameters
@@ -49,6 +53,68 @@ func TestConvertMCPToolToBifrostSchema_EmptyParameters(t *testing.T) {
 	}
 }
 
+// TestConvertMCPToolToBifrostSchema_WithAnnotations tests that MCP tool annotations
+// are preserved on ChatTool.Annotations (not ChatToolFunction) and are absent from JSON.
+func TestConvertMCPToolToBifrostSchema_WithAnnotations(t *testing.T) {
+	readOnly := true
+	destructive := false
+
+	mcpTool := &mcp.Tool{
+		Name:        "read_resource",
+		Description: "Reads a resource",
+		InputSchema: mcp.ToolInputSchema{
+			Type:       "object",
+			Properties: map[string]interface{}{},
+		},
+		Annotations: mcp.ToolAnnotation{
+			Title:           "Resource Reader",
+			ReadOnlyHint:    &readOnly,
+			DestructiveHint: &destructive,
+			IdempotentHint:  schemas.Ptr(true),
+		},
+	}
+
+	bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
+
+	// Annotations must be on ChatTool, not buried in Function
+	require.NotNil(t, bifrostTool.Annotations, "Annotations should be set on ChatTool")
+	assert.Equal(t, "Resource Reader", bifrostTool.Annotations.Title)
+	require.NotNil(t, bifrostTool.Annotations.ReadOnlyHint)
+	assert.True(t, *bifrostTool.Annotations.ReadOnlyHint)
+	require.NotNil(t, bifrostTool.Annotations.DestructiveHint)
+	assert.False(t, *bifrostTool.Annotations.DestructiveHint)
+	require.NotNil(t, bifrostTool.Annotations.IdempotentHint)
+	assert.True(t, *bifrostTool.Annotations.IdempotentHint)
+	assert.Nil(t, bifrostTool.Annotations.OpenWorldHint)
+
+	// The JSON sent to providers must not contain annotations
+	toolJSON, err := json.Marshal(bifrostTool)
+	require.NoError(t, err)
+	s := string(toolJSON)
+	assert.NotContains(t, s, "annotations", "annotations must be absent from provider JSON")
+	assert.NotContains(t, s, "readOnlyHint", "readOnlyHint must be absent from provider JSON")
+	assert.NotContains(t, s, "Resource Reader", "annotation title must be absent from provider JSON")
+}
+
+// TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero verifies the nil guard:
+// when all annotation fields are zero-valued, ChatTool.Annotations must remain nil.
+func TestConvertMCPToolToBifrostSchema_NilAnnotationsWhenAllZero(t *testing.T) {
+	mcpTool := &mcp.Tool{
+		Name:        "no_hints_tool",
+		Description: "A tool with no annotation hints",
+		InputSchema: mcp.ToolInputSchema{
+			Type:       "object",
+			Properties: map[string]interface{}{},
+		},
+		Annotations: mcp.ToolAnnotation{}, // All zero values — Title empty, all hints nil
+	}
+
+	bifrostTool := convertMCPToolToBifrostSchema(mcpTool, defaultLogger)
+
+	assert.Nil(t, bifrostTool.Annotations,
+		"Annotations should be nil when all MCP annotation fields are zero")
+}
+
 // TestConvertMCPToolToBifrostSchema_WithParameters tests the normal case with parameters
 func TestConvertMCPToolToBifrostSchema_WithParameters(t *testing.T) {
 	// Create a tool with parameters
diff --git a/core/providers/anthropic/anthropic.go b/core/providers/anthropic/anthropic.go
index e5d62fe87c..0fc6073ced 100644
--- a/core/providers/anthropic/anthropic.go
+++ b/core/providers/anthropic/anthropic.go
@@ -173,7 +173,7 @@ func extractAnthropicResponsesUsageFromPrefetch(data []byte) *schemas.ResponsesR
 // Returns the response body or an error if the request fails.
 // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx),
 // returns (nil, latency, nil) — callers must check the context flag.
-func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
+func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, requestType schemas.RequestType) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
 	// Create the request with the JSON body
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -208,7 +208,7 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext,
 	requestClient := provider.client
 	responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
-	isCountTokens := meta != nil && meta.RequestType == schemas.CountTokensRequest
+	isCountTokens := requestType == schemas.CountTokensRequest
 	// CountTokens responses are always tiny — skip streaming client so the response
 	// is buffered normally (same approach as OpenAI and Gemini count_tokens handlers).
 	if responseThreshold > 0 && !isCountTokens {
@@ -233,20 +233,20 @@ func (provider *AnthropicProvider) completeRequest(ctx *schemas.BifrostContext,
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))
-		return nil, latency, providerResponseHeaders, parseAnthropicError(resp, meta)
+		return nil, latency, providerResponseHeaders, parseAnthropicError(resp)
 	}
 
 	// CountTokens uses buffered response (streaming skipped above) — decode directly.
 	if isCountTokens {
 		body, err := providerUtils.CheckAndDecodeBody(resp)
 		if err != nil {
-			return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey())
+			return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 		}
 		return body, latency, providerResponseHeaders, nil
 	}
 
 	// Delegate large response detection + normal buffered path to shared utility
-	body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger)
+	body, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if respErr != nil {
 		return nil, latency, providerResponseHeaders, respErr
 	}
@@ -290,10 +290,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext,
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseAnthropicError(resp, &providerUtils.RequestMetadata{
-			Provider:    provider.GetProviderKey(),
-			RequestType: schemas.ListModelsRequest,
-		})
+		return nil, parseAnthropicError(resp)
 	}
 
 	// Parse Anthropic's response
@@ -304,7 +301,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext,
 	}
 
 	// Create final response
-	response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered)
+	response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	// Set raw request if enabled
@@ -355,18 +352,13 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToAnthropicTextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if err != nil {
 		return nil, err
 	}
 
 	// Use struct directly for JSON marshaling (no beta headers for text completion)
-	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{
-		Provider:    provider.GetProviderKey(),
-		Model:       request.Model,
-		RequestType: schemas.TextCompletionRequest,
-	})
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/complete", schemas.TextCompletionRequest), key.Value.GetValue(), schemas.TextCompletionRequest)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -379,9 +371,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k
 		return &schemas.BifrostTextCompletionResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				RequestType:             schemas.TextCompletionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -400,9 +389,6 @@ func (provider *AnthropicProvider) TextCompletion(ctx *schemas.BifrostContext, k
 	bifrostResponse := response.ToBifrostTextCompletionResponse()
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -444,18 +430,35 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k
 		}
 		AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic)
 		return anthropicReq, nil
-	},
-	provider.GetProviderKey())
+	})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
+	// On the raw-body passthrough path, the typed-struct StripUnsupportedAnthropicFields
+	// was not invoked. Apply the JSON-level sanitizer for behavioural parity so
+	// unsupported request-level and tool-level fields don't leak to providers that
+	// would reject them.
+	if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
+		// Feature gating keyed to schemas.Anthropic (not provider.GetProviderKey())
+		// so custom Anthropic aliases get the same feature lookup as the typed
+		// path above (line 445), keeping raw and typed behavior in lockstep.
+		sanitized, rawErr := stripUnsupportedFieldsFromRawBody(jsonData, schemas.Anthropic, request.Model)
+		if rawErr != nil {
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, rawErr)
+		}
+		jsonData = sanitized
+		// Auto-inject matching anthropic-beta headers for fields the sanitizer
+		// preserved. Probe-unmarshal reuses the typed path's header walker so
+		// the two paths stay in lockstep.
+		var probe AnthropicMessageRequest
+		if err := schemas.Unmarshal(jsonData, &probe); err == nil {
+			AddMissingBetaHeadersToContext(ctx, &probe, schemas.Anthropic)
+		}
+	}
+
 	// Use struct directly for JSON marshaling
-	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{
-		Provider:    provider.GetProviderKey(),
-		Model:       request.Model,
-		RequestType: schemas.ChatCompletionRequest,
-	})
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, provider.buildRequestURL(ctx, "/v1/messages", schemas.ChatCompletionRequest), key.Value.GetValue(), schemas.ChatCompletionRequest)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -468,9 +471,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k
 		return &schemas.BifrostChatResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ChatCompletionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -489,9 +489,6 @@ func (provider *AnthropicProvider) ChatCompletion(ctx *schemas.BifrostContext, k
 	bifrostResponse := response.ToBifrostChatResponse(ctx)
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -528,12 +525,30 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 		anthropicReq.Stream = schemas.Ptr(true)
 		AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Anthropic)
 		return anthropicReq, nil
-	},
-	provider.GetProviderKey())
+	})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
+	// On the raw-body passthrough path, the typed-struct StripUnsupportedAnthropicFields
+	// was not invoked. Apply the JSON-level sanitizer for behavioural parity.
+	if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
+		// Feature gating keyed to schemas.Anthropic (not provider.GetProviderKey())
+		// to keep raw and typed paths in lockstep on custom aliases — mirrors
+		// the typed path's hardcoded schemas.Anthropic at line 548.
+		sanitized, rawErr := stripUnsupportedFieldsFromRawBody(jsonData, schemas.Anthropic, request.Model)
+		if rawErr != nil {
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, rawErr)
+		}
+		jsonData = sanitized
+		// Auto-inject matching anthropic-beta headers for fields the sanitizer
+		// preserved. Probe-unmarshal reuses the typed path's header walker.
+		var probe AnthropicMessageRequest
+		if err := schemas.Unmarshal(jsonData, &probe); err == nil {
+			AddMissingBetaHeadersToContext(ctx, &probe, schemas.Anthropic)
+		}
+	}
+
 	// Prepare Anthropic headers
 	headers := map[string]string{
 		"Content-Type": "application/json",
@@ -563,11 +578,6 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 		postHookRunner,
 		nil,
 		provider.logger,
-		&providerUtils.RequestMetadata{
-			Provider:    provider.GetProviderKey(),
-			Model:       request.Model,
-			RequestType: schemas.ChatCompletionStreamRequest,
-		},
 	)
 }
@@ -587,7 +597,6 @@ func HandleAnthropicChatCompletionStreaming(
 	postHookRunner schemas.PostHookRunner,
 	postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse,
 	logger schemas.Logger,
-	meta *providerUtils.RequestMetadata,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -634,9 +643,9 @@ func HandleAnthropicChatCompletionStreaming(
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Store provider response headers in context before status check so error responses also forward them
@@ -645,7 +654,7 @@ func HandleAnthropicChatCompletionStreaming(
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -660,15 +669,12 @@ func HandleAnthropicChatCompletionStreaming(
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
-			model := "unknown"
-			if meta != nil {
-				model = meta.Model
-			}
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -678,7 +684,6 @@ func HandleAnthropicChatCompletionStreaming(
 			bifrostErr := providerUtils.NewBifrostOperationError(
 				"Provider returned an empty response",
 				fmt.Errorf("provider returned an empty response"),
-				providerName,
 			)
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
@@ -732,7 +737,7 @@ func HandleAnthropicChatCompletionStreaming(
 				if readErr != io.EOF {
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					logger.Warn("Error reading %s stream: %v", providerName, readErr)
-					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, modelName, logger)
+					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 					return
 				}
 				break
@@ -791,7 +796,6 @@ func HandleAnthropicChatCompletionStreaming(
 				}
 			}
 			if event.Message != nil {
-				// Handle different event types
 				modelName = event.Message.Model
 			}
@@ -840,11 +844,8 @@ func HandleAnthropicChatCompletionStreaming(
 						},
 					},
 					ExtraFields: schemas.BifrostResponseExtraFields{
-						RequestType:    schemas.ChatCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: modelName,
-						ChunkIndex:     chunkIndex,
-						Latency:        time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					},
 				}
 				lastChunkTime = time.Now()
@@ -868,22 +869,14 @@ func HandleAnthropicChatCompletionStreaming(
 				response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream(ctx, structuredOutputToolName, streamState)
 				if bifrostErr != nil {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.ChatCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: modelName,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger)
 					break
 				}
 				if response != nil {
 					response.ExtraFields = schemas.BifrostResponseExtraFields{
-						RequestType:    schemas.ChatCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: modelName,
-						ChunkIndex:     chunkIndex,
-						Latency:        time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					}
 					if postResponseConverter != nil {
 						response = postResponseConverter(response)
@@ -910,7 +903,7 @@ func HandleAnthropicChatCompletionStreaming(
 					usage.PromptTokens = usage.PromptTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens
 					usage.TotalTokens = usage.TotalTokens + usage.PromptTokensDetails.CachedReadTokens + usage.PromptTokensDetails.CachedWriteTokens
 				}
-				response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, modelName)
+				response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, modelName, 0)
 				if postResponseConverter != nil {
 					response = postResponseConverter(response)
 					if response == nil {
@@ -939,16 +932,12 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc
 	if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
 		return nil, err
 	}
-	jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, nil)
+	jsonBody, err := getRequestBodyForResponses(ctx, request, false, nil)
 	if err != nil {
 		return nil, err
 	}
-	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{
-		Provider:    provider.GetProviderKey(),
-		Model:       request.Model,
-		RequestType: schemas.ResponsesRequest,
-	})
+	responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages", schemas.ResponsesRequest), key.Value.GetValue(), schemas.ResponsesRequest)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -966,9 +955,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc
 			Model: request.Model,
 			Usage: extractAnthropicResponsesUsageFromPrefetch([]byte(preview)),
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                provider.GetProviderKey(),
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ResponsesRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -988,9 +974,6 @@ func (provider *AnthropicProvider) Responses(ctx *schemas.BifrostContext, key sc
 	bifrostResponse := response.ToBifrostResponsesResponse(ctx)
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = provider.GetProviderKey()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -1014,7 +997,7 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext,
 	}
 
 	// Convert to Anthropic format using the centralized converter
-	jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), true, nil)
+	jsonBody, err := getRequestBodyForResponses(ctx, request, true, nil)
 	if err != nil {
 		return nil, err
 	}
@@ -1047,11 +1030,6 @@ func (provider *AnthropicProvider) ResponsesStream(ctx *schemas.BifrostContext,
 		postHookRunner,
 		nil,
 		provider.logger,
-		&providerUtils.RequestMetadata{
-			Provider:    provider.GetProviderKey(),
-			Model:       request.Model,
-			RequestType: schemas.ResponsesStreamRequest,
-		},
 	)
 }
@@ -1071,7 +1049,6 @@ func HandleAnthropicResponsesStream(
 	postHookRunner schemas.PostHookRunner,
 	postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse,
 	logger schemas.Logger,
-	meta *providerUtils.RequestMetadata,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -1120,9 +1097,9 @@ func HandleAnthropicResponsesStream(
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Store provider response headers in context before status check so error responses also forward them
@@ -1131,7 +1108,7 @@ func HandleAnthropicResponsesStream(
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp, meta), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseAnthropicError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -1146,15 +1123,12 @@ func HandleAnthropicResponsesStream(
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
-			model := ""
-			if meta != nil {
-				model = meta.Model
-			}
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -1164,7 +1138,6 @@ func HandleAnthropicResponsesStream(
 			bifrostErr := providerUtils.NewBifrostOperationError(
 				"Provider returned an empty response",
 				fmt.Errorf("provider returned an empty response"),
-				providerName,
 			)
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
@@ -1216,7 +1189,7 @@ func HandleAnthropicResponsesStream(
 				if readErr != io.EOF {
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					logger.Warn("Error reading %s stream: %v", providerName, readErr)
-					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, modelName, logger)
+					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 				}
 				break
 			}
@@ -1286,11 +1259,6 @@ func HandleAnthropicResponsesStream(
 				ctx.SetValue(schemas.BifrostContextKeyHasEmittedMessageDelta, true)
 			}
 			if bifrostErr != nil {
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       providerName,
-					ModelRequested: modelName,
-				}
 				// If context was cancelled/timed out, let defer handle it
 				if ctx.Err() != nil {
 					return
@@ -1307,12 +1275,9 @@ func HandleAnthropicResponsesStream(
 				Type:           schemas.ResponsesStreamResponseType(eventType),
 				SequenceNumber: chunkIndex,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       providerName,
-					ModelRequested: modelName,
-					ChunkIndex:     chunkIndex,
-					Latency:        time.Since(lastChunkTime).Milliseconds(),
-					RawResponse:    eventData,
+					ChunkIndex:  chunkIndex,
+					Latency:     time.Since(lastChunkTime).Milliseconds(),
+					RawResponse: eventData,
 				},
 			}
 			lastChunkTime = time.Now()
@@ -1326,11 +1291,8 @@ func HandleAnthropicResponsesStream(
 			for i, response := range responses {
 				if response != nil {
 					response.ExtraFields = schemas.BifrostResponseExtraFields{
-						RequestType:    schemas.ResponsesStreamRequest,
-						Provider:       providerName,
-						ModelRequested: modelName,
-						ChunkIndex:     chunkIndex,
-						Latency:        time.Since(lastChunkTime).Milliseconds(),
+						ChunkIndex: chunkIndex,
+						Latency:    time.Since(lastChunkTime).Milliseconds(),
 					}
 					if postResponseConverter != nil {
 						response = postResponseConverter(response)
@@ -1384,7 +1346,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key
 	providerName := provider.GetProviderKey()
 
 	if len(request.Requests) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("requests array is required for Anthropic batch API", nil)
 	}
 
 	// Create request
@@ -1422,7 +1384,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key
 	jsonData, err := providerUtils.MarshalSorted(anthropicReq)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 
 	usedLargePayloadBody := setAnthropicRequestBody(ctx, req, jsonData)
@@ -1442,12 +1404,12 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseAnthropicError(resp, schemas.BatchCreateRequest, providerName, "")
+		return nil, parseAnthropicError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 
 	var anthropicResp AnthropicBatchResponse
@@ -1456,7 +1418,7 @@ func (provider *AnthropicProvider) BatchCreate(ctx *schemas.BifrostContext, key
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
 	}
 
-	return anthropicResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
+	return anthropicResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
 }
 
 // BatchList lists batch jobs using serial pagination across keys.
@@ -1472,7 +1434,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [
 	// Initialize serial pagination helper (Anthropic uses AfterID for pagination)
 	helper, err := providerUtils.NewSerialListHelper(keys, request.AfterID, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 
 	// Get current key to query
@@ -1483,10 +1445,6 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [
 			Object:  "list",
 			Data:    []schemas.BifrostBatchRetrieveResponse{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}
@@ -1535,12 +1493,12 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseAnthropicError(resp, schemas.BatchListRequest, providerName, "")
+		return nil, parseAnthropicError(resp)
 	}
 
 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}
 
 	var anthropicResp AnthropicBatchListResponse
@@ -1553,7 +1511,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [
 	batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(anthropicResp.Data))
 	var lastBatchID string
 	for _, batch := range anthropicResp.Data {
-		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, false, false, nil, nil))
+		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, false, false, nil, nil))
 		lastBatchID = batch.ID
 	}
@@ -1567,9 +1525,7 @@ func (provider *AnthropicProvider) BatchList(ctx *schemas.BifrostContext, keys [
 		Data:    batches,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 	if nextCursor != "" {
@@ -1587,7 +1543,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke
 	// batch id is required
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
 	providerName := provider.GetProviderKey()
@@ -1628,7 +1584,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		lastErr = ParseAnthropicError(resp, schemas.BatchRetrieveRequest, providerName, "")
+		lastErr = parseAnthropicError(resp)
 		wait()
fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1640,7 +1596,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1658,8 +1614,7 @@ func (provider *AnthropicProvider) BatchRetrieve(ctx *schemas.BifrostContext, ke fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := anthropicResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := anthropicResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -1674,7 +1629,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // batch id is required if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1711,7 +1666,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1723,7 +1678,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1746,9 +1701,7 @@ func (provider *AnthropicProvider) BatchCancel(ctx *schemas.BifrostContext, keys Object: anthropicResp.Type, Status: ToBifrostBatchStatus(anthropicResp.ProcessingStatus), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1791,7 +1744,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key } if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.Anthropic) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } providerName := provider.GetProviderKey() @@ -1825,7 +1778,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.BatchResultsRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -1837,7 +1790,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -1879,9 +1832,7 @@ func (provider *AnthropicProvider) BatchResults(ctx *schemas.BifrostContext, key BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: 
schemas.BatchResultsRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1964,7 +1915,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -1978,14 +1929,14 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -2017,12 +1968,12 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileUploadRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, 
providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var anthropicResp AnthropicFileResponse @@ -2033,7 +1984,7 @@ func (provider *AnthropicProvider) FileUpload(ctx *schemas.BifrostContext, key s return nil, bifrostErr } - return anthropicResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // FileList lists files from all provided keys and aggregates results. @@ -2051,7 +2002,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2062,10 +2013,6 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2111,12 +2058,12 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseAnthropicError(resp, schemas.FileListRequest, providerName, "") + return nil, parseAnthropicError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var anthropicResp AnthropicFileListResponse @@ -2151,9 +2098,7 @@ func (provider *AnthropicProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2172,7 +2117,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2213,7 +2158,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2225,7 +2170,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2243,7 +2188,7 @@ func (provider *AnthropicProvider) FileRetrieve(ctx *schemas.BifrostContext, key fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - return anthropicResp.ToBifrostFileRetrieveResponse(providerName, latency, 
sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return anthropicResp.ToBifrostFileRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } return nil, lastErr @@ -2258,7 +2203,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2295,7 +2240,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2312,9 +2257,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2324,7 +2267,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2347,9 +2290,7 @@ func (provider *AnthropicProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: 
anthropicResp.Type == "file_deleted", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -2377,7 +2318,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -2409,7 +2350,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseAnthropicError(resp, schemas.FileContentRequest, providerName, "") + lastErr = parseAnthropicError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2421,7 +2362,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2441,9 +2382,7 @@ func (provider *AnthropicProvider) FileContent(ctx *schemas.BifrostContext, keys Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2456,16 +2395,12 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key if err := providerUtils.CheckOperationAllowed(schemas.Anthropic, provider.customProviderConfig, 
schemas.CountTokensRequest); err != nil { return nil, err } - jsonBody, err := getRequestBodyForResponses(ctx, request, provider.GetProviderKey(), false, []string{"max_tokens", "temperature"}) + jsonBody, err := getRequestBodyForResponses(ctx, request, false, []string{"max_tokens", "temperature"}) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }) + responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v1/messages/count_tokens", schemas.CountTokensRequest), key.Value.GetValue(), schemas.CountTokensRequest) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -2489,9 +2424,6 @@ func (provider *AnthropicProvider) CountTokens(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) response.Model = request.Model - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2626,7 +2558,7 @@ func (provider *AnthropicProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { @@ -2641,9 +2573,6 @@ func (provider 
*AnthropicProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2707,9 +2636,9 @@ func (provider *AnthropicProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -2720,7 +2649,6 @@ func (provider *AnthropicProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -2732,11 +2660,7 @@ func (provider *AnthropicProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2745,11 +2669,12 @@ func (provider *AnthropicProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -2798,7 +2723,7 @@ func (provider *AnthropicProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/anthropic/anthropic_test.go b/core/providers/anthropic/anthropic_test.go index d64b10aa82..6cb05f8c8c 100644 --- a/core/providers/anthropic/anthropic_test.go +++ b/core/providers/anthropic/anthropic_test.go @@ -72,7 +72,9 @@ func TestAnthropic(t *testing.T) { PassthroughAPI: 
true, Compaction: true, InterleavedThinking: true, - FastMode: false, // Enable when test API key has Opus 4.6 access + FastMode: false, // Enable when test API key has Opus 4.6 access + EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (GA on Anthropic) + ServerToolsViaOpenAIEndpoint: true, // web_search / web_fetch / code_execution via /v1/chat/completions }, } diff --git a/core/providers/anthropic/batch.go b/core/providers/anthropic/batch.go index 405738330c..ac4b0940c4 100644 --- a/core/providers/anthropic/batch.go +++ b/core/providers/anthropic/batch.go @@ -3,9 +3,7 @@ package anthropic import ( "time" - providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" - "github.com/valyala/fasthttp" ) // Anthropic Batch API Types @@ -129,7 +127,7 @@ func ToBifrostObjectType(anthropicType string) string { } // ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response. -func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { expiresAt := parseAnthropicTimestamp(r.ExpiresAt) resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, @@ -140,9 +138,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(providerName schem CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExpiresAt: &expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -170,7 +166,7 @@ func (r *AnthropicBatchResponse) 
ToBifrostBatchCreateResponse(providerName schem } // ToBifrostBatchRetrieveResponse converts Anthropic batch response to Bifrost batch retrieve response. -func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: ToBifrostObjectType(r.Type), @@ -179,9 +175,7 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch ResultsURL: r.ResultsURL, CreatedAt: parseAnthropicTimestamp(r.CreatedAt), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -228,26 +222,6 @@ func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(providerName sch return resp } -// ParseAnthropicError parses Anthropic error responses for batch operations. 
-func ParseAnthropicError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { - var errorResp AnthropicError - bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) - if errorResp.Error != nil { - if errorResp.Error.Type != "" { - bifrostErr.Error.Type = &errorResp.Error.Type - } - if errorResp.Error.Message != "" { - bifrostErr.Error.Message = errorResp.Error.Message - } - } - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: providerName, - ModelRequested: model, - } - return bifrostErr -} - // ToAnthropicBatchCreateResponse converts a Bifrost batch create response to Anthropic format. func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse { result := &AnthropicBatchResponse{ diff --git a/core/providers/anthropic/chat.go b/core/providers/anthropic/chat.go index 93c0e7c1d0..5cf221b84e 100644 --- a/core/providers/anthropic/chat.go +++ b/core/providers/anthropic/chat.go @@ -3,6 +3,7 @@ package anthropic import ( "encoding/json" "fmt" + "strings" "time" "github.com/bytedance/sonic" @@ -10,6 +11,231 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// convertFunctionToolToAnthropic turns an OpenAI-style function tool +// (schemas.ChatTool with non-nil Function) into an AnthropicTool. +// Factored out from ToAnthropicChatRequest's tool loop so the loop can branch +// cleanly between function and server-tool shapes. 
+func convertFunctionToolToAnthropic(tool schemas.ChatTool) AnthropicTool { + anthropicTool := AnthropicTool{ + Name: tool.Function.Name, + } + if tool.Function.Description != nil { + anthropicTool.Description = tool.Function.Description + } + + // Convert function parameters to input_schema + if tool.Function.Parameters != nil && (tool.Function.Parameters.Type != "" || tool.Function.Parameters.Properties != nil) { + anthropicTool.InputSchema = &schemas.ToolFunctionParameters{ + Type: tool.Function.Parameters.Type, + Description: tool.Function.Parameters.Description, + Properties: tool.Function.Parameters.Properties, + Required: tool.Function.Parameters.Required, + Enum: tool.Function.Parameters.Enum, + AdditionalProperties: tool.Function.Parameters.AdditionalProperties, + Defs: tool.Function.Parameters.Defs, + Definitions: tool.Function.Parameters.Definitions, + Ref: tool.Function.Parameters.Ref, + Items: tool.Function.Parameters.Items, + MinItems: tool.Function.Parameters.MinItems, + MaxItems: tool.Function.Parameters.MaxItems, + AnyOf: tool.Function.Parameters.AnyOf, + OneOf: tool.Function.Parameters.OneOf, + AllOf: tool.Function.Parameters.AllOf, + Format: tool.Function.Parameters.Format, + Pattern: tool.Function.Parameters.Pattern, + MinLength: tool.Function.Parameters.MinLength, + MaxLength: tool.Function.Parameters.MaxLength, + Minimum: tool.Function.Parameters.Minimum, + Maximum: tool.Function.Parameters.Maximum, + Title: tool.Function.Parameters.Title, + Default: tool.Function.Parameters.Default, + Nullable: tool.Function.Parameters.Nullable, + } + } + + if anthropicTool.InputSchema != nil { + anthropicTool.InputSchema = anthropicTool.InputSchema.Normalized() + } + + if tool.CacheControl != nil { + anthropicTool.CacheControl = tool.CacheControl + } + if tool.DeferLoading != nil { + anthropicTool.DeferLoading = tool.DeferLoading + } + if len(tool.AllowedCallers) > 0 { + anthropicTool.AllowedCallers = tool.AllowedCallers + } + if len(tool.InputExamples) > 0 
{ + anthropicTool.InputExamples = make([]AnthropicToolInputExample, len(tool.InputExamples)) + for i, ex := range tool.InputExamples { + anthropicTool.InputExamples[i] = AnthropicToolInputExample{ + Input: ex.Input, + Description: ex.Description, + } + } + } + if tool.EagerInputStreaming != nil { + anthropicTool.EagerInputStreaming = tool.EagerInputStreaming + } + // ChatToolFunction.Strict is the canonical neutral slot for Anthropic's strict. + if tool.Function.Strict != nil { + anthropicTool.Strict = tool.Function.Strict + } + return anthropicTool +} + +// convertServerToolToAnthropic reconstructs an AnthropicTool from the +// server-tool shape of a schemas.ChatTool (Function=nil, Name+Type+variant +// fields populated). Returns (zero, false) only when the required identity +// fields are missing (empty Type, empty Name, or an mcp_toolset without +// MCPServerName), so the caller can drop a malformed tool cleanly; +// unknown-but-named types are forwarded as-is. +// +// Supported type prefixes: +// - web_search_* → AnthropicToolWebSearch +// - web_fetch_* → AnthropicToolWebFetch +// - computer_* → AnthropicToolComputerUse +// - text_editor_* → AnthropicToolTextEditor +// - mcp_toolset → AnthropicMCPToolsetTool (via MCPToolset pointer) +// +// bash_*, memory_*, code_execution_*, and tool_search_tool_* carry no variant +// config — their Type + Name alone are enough, handled by the shared +// no-config case in the switch. +func convertServerToolToAnthropic(tool schemas.ChatTool) (AnthropicTool, bool) { + typeStr := string(tool.Type) + if typeStr == "" { + return AnthropicTool{}, false + } + + // mcp_toolset is serialized via a dedicated embedded type (AnthropicMCPToolsetTool) + // and carries its identity in MCPServerName, not Name — handle before the + // generic Name guard below.
+ if typeStr == "mcp_toolset" { + if tool.MCPServerName == "" { + return AnthropicTool{}, false + } + toolset := &AnthropicMCPToolsetTool{ + Type: "mcp_toolset", + MCPServerName: tool.MCPServerName, + DefaultConfig: convertMCPToolsetConfig(tool.DefaultConfig), + Configs: convertMCPToolsetConfigMap(tool.Configs), + CacheControl: tool.CacheControl, + } + return AnthropicTool{MCPToolset: toolset}, true + } + + // Remaining server tools (web_search, web_fetch, computer, text_editor, etc.) + // identify themselves via Name. + if tool.Name == "" { + return AnthropicTool{}, false + } + + atype := AnthropicToolType(typeStr) + anthropicTool := AnthropicTool{ + Name: tool.Name, + Type: &atype, + CacheControl: tool.CacheControl, + DeferLoading: tool.DeferLoading, + AllowedCallers: tool.AllowedCallers, + EagerInputStreaming: tool.EagerInputStreaming, + } + if len(tool.InputExamples) > 0 { + anthropicTool.InputExamples = make([]AnthropicToolInputExample, len(tool.InputExamples)) + for i, ex := range tool.InputExamples { + anthropicTool.InputExamples[i] = AnthropicToolInputExample{ + Input: ex.Input, + Description: ex.Description, + } + } + } + + switch { + case strings.HasPrefix(typeStr, "web_search_"): + anthropicTool.AnthropicToolWebSearch = &AnthropicToolWebSearch{ + MaxUses: tool.MaxUses, + AllowedDomains: tool.AllowedDomains, + BlockedDomains: tool.BlockedDomains, + UserLocation: convertUserLocation(tool.UserLocation), + } + case strings.HasPrefix(typeStr, "web_fetch_"): + anthropicTool.AnthropicToolWebFetch = &AnthropicToolWebFetch{ + MaxUses: tool.MaxUses, + AllowedDomains: tool.AllowedDomains, + BlockedDomains: tool.BlockedDomains, + MaxContentTokens: tool.MaxContentTokens, + Citations: convertCitationsConfig(tool.Citations), + UseCache: tool.UseCache, + } + case strings.HasPrefix(typeStr, "computer_"): + anthropicTool.AnthropicToolComputerUse = &AnthropicToolComputerUse{ + DisplayWidthPx: tool.DisplayWidthPx, + DisplayHeightPx: tool.DisplayHeightPx, + DisplayNumber: 
tool.DisplayNumber, + EnableZoom: tool.EnableZoom, + } + case strings.HasPrefix(typeStr, "text_editor_"): + anthropicTool.AnthropicToolTextEditor = &AnthropicToolTextEditor{ + MaxCharacters: tool.MaxCharacters, + } + case strings.HasPrefix(typeStr, "bash_"), + strings.HasPrefix(typeStr, "memory_"), + strings.HasPrefix(typeStr, "code_execution_"), + strings.HasPrefix(typeStr, "tool_search_tool_"): + // No variant-specific config — Type + Name alone. + default: + // Unknown type — pass through Type + Name and let Anthropic reject + // if it's truly invalid. This keeps forward-compat for new tool + // versions that aren't yet known to Bifrost. + } + return anthropicTool, true +} + +// convertUserLocation mirrors schemas.ChatToolUserLocation onto +// AnthropicToolWebSearchUserLocation. +func convertUserLocation(loc *schemas.ChatToolUserLocation) *AnthropicToolWebSearchUserLocation { + if loc == nil { + return nil + } + return &AnthropicToolWebSearchUserLocation{ + Type: loc.Type, + City: loc.City, + Region: loc.Region, + Country: loc.Country, + Timezone: loc.Timezone, + } +} + +// convertCitationsConfig mirrors the request-side citations config +// ({"enabled": true/false}) onto AnthropicCitations' request form. +func convertCitationsConfig(c *schemas.ChatToolCitationsConfig) *AnthropicCitations { + if c == nil { + return nil + } + return &AnthropicCitations{Config: &schemas.Citations{Enabled: c.Enabled}} +} + +// convertMCPToolsetConfig mirrors a single mcp_toolset config. +func convertMCPToolsetConfig(c *schemas.ChatMCPToolsetConfig) *AnthropicMCPToolsetConfig { + if c == nil { + return nil + } + return &AnthropicMCPToolsetConfig{ + Enabled: c.Enabled, + DeferLoading: c.DeferLoading, + } +} + +// convertMCPToolsetConfigMap mirrors the per-tool mcp_toolset configs map. 
+func convertMCPToolsetConfigMap(m map[string]*schemas.ChatMCPToolsetConfig) map[string]*AnthropicMCPToolsetConfig { + if len(m) == 0 { + return nil + } + out := make(map[string]*AnthropicMCPToolsetConfig, len(m)) + for k, v := range m { + out[k] = convertMCPToolsetConfig(v) + } + return out +} + // ToAnthropicChatRequest converts a Bifrost request to Anthropic format // This is the reverse of ConvertChatRequestToBifrost for provider-side usage func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*AnthropicMessageRequest, error) { @@ -30,29 +256,59 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif anthropicReq.MaxTokens = *bifrostReq.Params.MaxCompletionTokens } - // Anthropic doesn't allow both temperature and top_p to be specified - // If both are present, prefer temperature (more commonly used) - if bifrostReq.Params.Temperature != nil { - anthropicReq.Temperature = bifrostReq.Params.Temperature - } else if bifrostReq.Params.TopP != nil { - anthropicReq.TopP = bifrostReq.Params.TopP + // Opus 4.7+ rejects temperature, top_p, and top_k with a 400 error. + if !IsOpus47(bifrostReq.Model) { + // Anthropic doesn't allow both temperature and top_p to be specified. + // If both are present, prefer temperature (more commonly used). + if bifrostReq.Params.Temperature != nil { + anthropicReq.Temperature = bifrostReq.Params.Temperature + } else if bifrostReq.Params.TopP != nil { + anthropicReq.TopP = bifrostReq.Params.TopP + } } anthropicReq.StopSequences = bifrostReq.Params.Stop - topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]) - if ok { + + // TopK — prefer the promoted neutral field; fall back to ExtraParams. + // Opus 4.7+ rejects top_k with a 400 error. 
+ if bifrostReq.Params.TopK != nil { + if !IsOpus47(bifrostReq.Model) { + anthropicReq.TopK = bifrostReq.Params.TopK + } + } else if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { delete(anthropicReq.ExtraParams, "top_k") - anthropicReq.TopK = topK + if !IsOpus47(bifrostReq.Model) { + anthropicReq.TopK = topK + } } - if speed, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["speed"]); ok { + + // Speed — prefer neutral field, then ExtraParams. + if bifrostReq.Params.Speed != nil { + anthropicReq.Speed = bifrostReq.Params.Speed + } else if speed, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["speed"]); ok { delete(anthropicReq.ExtraParams, "speed") anthropicReq.Speed = speed } - // extract inference_geo and context management - if inferenceGeo, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["inference_geo"]); ok { + + // InferenceGeo — prefer neutral field, then ExtraParams. + if bifrostReq.Params.InferenceGeo != nil { + anthropicReq.InferenceGeo = bifrostReq.Params.InferenceGeo + } else if inferenceGeo, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["inference_geo"]); ok { delete(anthropicReq.ExtraParams, "inference_geo") anthropicReq.InferenceGeo = inferenceGeo } - if cmVal := bifrostReq.Params.ExtraParams["context_management"]; cmVal != nil { + + // ContextManagement — the neutral type is json.RawMessage; decode to + // the Anthropic-shape ContextManagement. Fall back to ExtraParams + // (legacy map-valued or typed-pointer paths) if the raw is empty. + // Surface decode errors on the typed path so callers get immediate + // feedback on malformed config instead of a silent drop. 
+ if len(bifrostReq.Params.ContextManagement) > 0 { + var cm ContextManagement + if err := sonic.Unmarshal(bifrostReq.Params.ContextManagement, &cm); err != nil { + return nil, fmt.Errorf("context_management: failed to parse: %w", err) + } + anthropicReq.ContextManagement = &cm + } else if cmVal := bifrostReq.Params.ExtraParams["context_management"]; cmVal != nil { if cm, ok := cmVal.(*ContextManagement); ok && cm != nil { delete(anthropicReq.ExtraParams, "context_management") anthropicReq.ContextManagement = cm @@ -64,6 +320,65 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif } } } + + // Container — map the neutral ChatContainer union onto the Anthropic + // AnthropicContainer union. Both follow the string-or-object pattern. + if bifrostReq.Params.Container != nil { + c := &AnthropicContainer{} + if bifrostReq.Params.Container.ContainerStr != nil { + c.ContainerStr = bifrostReq.Params.Container.ContainerStr + } else if bifrostReq.Params.Container.ContainerObject != nil { + obj := &AnthropicContainerObject{ + ID: bifrostReq.Params.Container.ContainerObject.ID, + } + if len(bifrostReq.Params.Container.ContainerObject.Skills) > 0 { + obj.Skills = make([]AnthropicContainerSkill, len(bifrostReq.Params.Container.ContainerObject.Skills)) + for i, sk := range bifrostReq.Params.Container.ContainerObject.Skills { + obj.Skills[i] = AnthropicContainerSkill{ + SkillID: sk.SkillID, + Type: sk.Type, + Version: sk.Version, + } + } + } + c.ContainerObject = obj + } + anthropicReq.Container = c + } + + // Top-level CacheControl on the request. + if bifrostReq.Params.CacheControl != nil { + anthropicReq.CacheControl = bifrostReq.Params.CacheControl + } + + // TaskBudget — maps onto output_config.task_budget. If an OutputConfig + // already exists (e.g. from structured outputs), attach the budget to + // it; otherwise create one. 
+ if bifrostReq.Params.TaskBudget != nil { + tb := &AnthropicTaskBudget{ + Type: bifrostReq.Params.TaskBudget.Type, + Total: bifrostReq.Params.TaskBudget.Total, + Remaining: bifrostReq.Params.TaskBudget.Remaining, + } + if anthropicReq.OutputConfig == nil { + anthropicReq.OutputConfig = &AnthropicOutputConfig{} + } + anthropicReq.OutputConfig.TaskBudget = tb + } + + // MCPServers — mirror the neutral ChatMCPServer[] to AnthropicMCPServerV2[]. + if len(bifrostReq.Params.MCPServers) > 0 { + servers := make([]AnthropicMCPServerV2, len(bifrostReq.Params.MCPServers)) + for i, s := range bifrostReq.Params.MCPServers { + servers[i] = AnthropicMCPServerV2{ + Type: s.Type, + URL: s.URL, + Name: s.Name, + AuthorizationToken: s.AuthorizationToken, + } + } + anthropicReq.MCPServers = servers + } if bifrostReq.Params.ResponseFormat != nil { // Vertex doesn't support native structured outputs, so convert to tool if bifrostReq.Provider == schemas.Vertex { @@ -87,65 +402,32 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif } } - // Convert tools + // Convert tools. Three neutral ChatTool shapes are supported: + // (1) Function tool (tool.Function != nil) — existing path. + // (2) Anthropic server tool (tool.Function == nil, Type is a + // server-tool version string, Name populated at top level) — + // new path handled by convertServerToolToAnthropic. + // (3) Custom tool (tool.Custom != nil) — not currently forwarded + // to Anthropic; skipped. if bifrostReq.Params.Tools != nil { - tools := make([]AnthropicTool, 0, len(bifrostReq.Params.Tools)) - for _, tool := range bifrostReq.Params.Tools { - if tool.Function == nil { + // Strip server tools the target provider doesn't support per + // ProviderFeatures (e.g. web_search on Vertex's non-supporting + // model variants, or MCP on Bedrock when this converter is used + // by non-Bedrock providers). Function/custom tools are always + // kept. 
The dropped set is discarded — "silent strip + continue" + // policy per user direction. See Bedrock's convertToolConfig for + // the direct-Bedrock-path equivalent. + filtered, _ := ValidateChatToolsForProvider(bifrostReq.Params.Tools, bifrostReq.Provider) + tools := make([]AnthropicTool, 0, len(filtered)) + for _, tool := range filtered { + if tool.Function != nil { + tools = append(tools, convertFunctionToolToAnthropic(tool)) continue } - anthropicTool := AnthropicTool{ - Name: tool.Function.Name, + // Non-function tool: attempt server-tool reconstruction. + if converted, ok := convertServerToolToAnthropic(tool); ok { + tools = append(tools, converted) } - if tool.Function.Description != nil { - anthropicTool.Description = tool.Function.Description - } - - // Convert function parameters to input_schema - if tool.Function.Parameters != nil && (tool.Function.Parameters.Type != "" || tool.Function.Parameters.Properties != nil) { - anthropicTool.InputSchema = &schemas.ToolFunctionParameters{ - Type: tool.Function.Parameters.Type, - Description: tool.Function.Parameters.Description, - Properties: tool.Function.Parameters.Properties, - Required: tool.Function.Parameters.Required, - Enum: tool.Function.Parameters.Enum, - AdditionalProperties: tool.Function.Parameters.AdditionalProperties, - // JSON Schema definition fields - Defs: tool.Function.Parameters.Defs, - Definitions: tool.Function.Parameters.Definitions, - Ref: tool.Function.Parameters.Ref, - // Array schema fields - Items: tool.Function.Parameters.Items, - MinItems: tool.Function.Parameters.MinItems, - MaxItems: tool.Function.Parameters.MaxItems, - // Composition fields - AnyOf: tool.Function.Parameters.AnyOf, - OneOf: tool.Function.Parameters.OneOf, - AllOf: tool.Function.Parameters.AllOf, - // String validation fields - Format: tool.Function.Parameters.Format, - Pattern: tool.Function.Parameters.Pattern, - MinLength: tool.Function.Parameters.MinLength, - MaxLength: tool.Function.Parameters.MaxLength, - // 
Number validation fields - Minimum: tool.Function.Parameters.Minimum, - Maximum: tool.Function.Parameters.Maximum, - // Misc fields - Title: tool.Function.Parameters.Title, - Default: tool.Function.Parameters.Default, - Nullable: tool.Function.Parameters.Nullable, - } - } - - if anthropicTool.InputSchema != nil { - anthropicTool.InputSchema = anthropicTool.InputSchema.Normalized() - } - - if tool.CacheControl != nil { - anthropicTool.CacheControl = tool.CacheControl - } - - tools = append(tools, anthropicTool) } if anthropicReq.Tools == nil { anthropicReq.Tools = tools @@ -189,23 +471,28 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif // Convert reasoning if bifrostReq.Params.Reasoning != nil { if bifrostReq.Params.Reasoning.MaxTokens != nil { - budgetTokens := *bifrostReq.Params.Reasoning.MaxTokens - if *bifrostReq.Params.Reasoning.MaxTokens == -1 { - // anthropic does not support dynamic reasoning budget like gemini - // setting it to default max tokens - budgetTokens = MinimumReasoningMaxTokens - } - if budgetTokens < MinimumReasoningMaxTokens { - return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) - } - anthropicReq.Thinking = &AnthropicThinking{ - Type: "enabled", - BudgetTokens: schemas.Ptr(budgetTokens), + if IsOpus47(bifrostReq.Model) { + // Opus 4.7+: budget_tokens removed; adaptive thinking is the only thinking-on mode. 
+ anthropicReq.Thinking = &AnthropicThinking{Type: "adaptive"} + } else { + budgetTokens := *bifrostReq.Params.Reasoning.MaxTokens + if *bifrostReq.Params.Reasoning.MaxTokens == -1 { + // anthropic does not support dynamic reasoning budget like gemini + // setting it to default max tokens + budgetTokens = MinimumReasoningMaxTokens + } + if budgetTokens < MinimumReasoningMaxTokens { + return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) + } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: schemas.Ptr(budgetTokens), + } } } else if bifrostReq.Params.Reasoning.Effort != nil && *bifrostReq.Params.Reasoning.Effort != "none" { effort := MapBifrostEffortToAnthropic(*bifrostReq.Params.Reasoning.Effort) - if SupportsAdaptiveThinking(bifrostReq.Model) { - // Opus 4.6+: adaptive thinking + native effort + if SupportsAdaptiveThinking(bifrostReq.Model) || IsOpus47(bifrostReq.Model) { + // Opus 4.6+ and Opus 4.7+: adaptive thinking + native effort anthropicReq.Thinking = &AnthropicThinking{Type: "adaptive"} setEffortOnOutputConfig(anthropicReq, effort) } else if SupportsNativeEffort(bifrostReq.Model) { @@ -235,6 +522,18 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif Type: "disabled", } } + + // thinking.display — map the neutral ChatReasoning.Display onto + // AnthropicThinking.Display. Valid for "enabled" and "adaptive" + // modes only; Anthropic rejects display on "disabled" ("there is + // nothing to display", per the extended-thinking doc). We attach + // on non-disabled modes and let the upstream provider enforce + // model-level support. 
+ if bifrostReq.Params.Reasoning.Display != nil && + anthropicReq.Thinking != nil && + anthropicReq.Thinking.Type != "disabled" { + anthropicReq.Thinking.Display = bifrostReq.Params.Reasoning.Display + } } // Convert service tier @@ -407,6 +706,11 @@ func ToAnthropicChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.Bif anthropicReq.Messages = anthropicMessages anthropicReq.System = systemContent + // Strip request- and tool-level fields the target Anthropic-family + // provider does not support. Fail-closed tool validation stays in + // ValidateToolsForProvider; this is strip-silently for additive fields. + stripUnsupportedAnthropicFields(anthropicReq, bifrostReq.Provider, bifrostReq.Model) + return anthropicReq, nil } @@ -418,12 +722,8 @@ func (response *AnthropicMessageResponse) ToBifrostChatResponse(ctx *schemas.Bif // Initialize Bifrost response bifrostResponse := &schemas.BifrostChatResponse{ - ID: response.ID, - Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Anthropic, - }, + ID: response.ID, + Model: response.Model, Created: int(time.Now().Unix()), } diff --git a/core/providers/anthropic/chat_server_tools_test.go b/core/providers/anthropic/chat_server_tools_test.go new file mode 100644 index 0000000000..cf830fab02 --- /dev/null +++ b/core/providers/anthropic/chat_server_tools_test.go @@ -0,0 +1,366 @@ +package anthropic + +import ( + "encoding/json" + "testing" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// TestChatTool_ServerToolRoundTrip verifies that every Anthropic server-tool +// variant survives Marshal/Unmarshal through the neutral ChatTool schema. +// This locks in the fix for the user-reported bug where a raw JSON tool like +// {"type":"web_search_20260209","name":"web_search","max_uses":5} was being +// dropped at the neutral-schema layer because ChatTool had no slots for the +// server-tool metadata. 
+func TestChatTool_ServerToolRoundTrip(t *testing.T) { + five := 5 + ptrTrue := true + w, h := 1280, 800 + maxChars := 16000 + maxContent := 32000 + + cases := []struct { + name string + raw string + }{ + { + name: "web_search_20260209", + raw: `{"type":"web_search_20260209","name":"web_search","max_uses":5,"allowed_callers":["direct"]}`, + }, + { + name: "web_search_with_domains", + raw: `{"type":"web_search_20250305","name":"web_search","allowed_domains":["example.com","docs.example.com"]}`, + }, + { + name: "web_search_with_user_location", + raw: `{"type":"web_search_20250305","name":"web_search","user_location":{"type":"approximate","city":"San Francisco","country":"US","timezone":"America/Los_Angeles"}}`, + }, + { + name: "web_fetch_20260309", + raw: `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":5,"max_content_tokens":32000,"citations":{"enabled":true},"use_cache":true}`, + }, + { + name: "computer_20251124", + raw: `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800,"display_number":1,"enable_zoom":true}`, + }, + { + name: "text_editor_20250728", + raw: `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`, + }, + { + name: "bash_20250124", + raw: `{"type":"bash_20250124","name":"bash"}`, + }, + { + name: "memory_20250818", + raw: `{"type":"memory_20250818","name":"memory"}`, + }, + { + name: "code_execution_20250825", + raw: `{"type":"code_execution_20250825","name":"code_execution"}`, + }, + { + name: "tool_search_tool_bm25", + raw: `{"type":"tool_search_tool_bm25","name":"tool_search_tool_bm25"}`, + }, + { + name: "mcp_toolset", + raw: `{"type":"mcp_toolset","name":"my_mcp","mcp_server_name":"notion","configs":{"search":{"enabled":true}}}`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Variant-specific field assertions. 
Invoked twice — once after + // initial decode, once after round-trip — so that a regression in + // MarshalSorted that silently drops any variant-specific field + // fails this test instead of sneaking through. + assertVariantFields := func(label string, tl schemas.ChatTool) { + t.Helper() + switch tc.name { + case "web_search_20260209": + if tl.MaxUses == nil || *tl.MaxUses != five { + t.Errorf("%s: MaxUses not preserved, got %v", label, tl.MaxUses) + } + if len(tl.AllowedCallers) != 1 || tl.AllowedCallers[0] != "direct" { + t.Errorf("%s: AllowedCallers not preserved, got %v", label, tl.AllowedCallers) + } + case "web_fetch_20260309": + if tl.MaxContentTokens == nil || *tl.MaxContentTokens != maxContent { + t.Errorf("%s: MaxContentTokens not preserved, got %v", label, tl.MaxContentTokens) + } + if tl.Citations == nil || tl.Citations.Enabled == nil || !*tl.Citations.Enabled { + t.Errorf("%s: Citations not preserved, got %v", label, tl.Citations) + } + if tl.UseCache == nil || !*tl.UseCache { + t.Errorf("%s: UseCache not preserved", label) + } + _ = ptrTrue + case "computer_20251124": + if tl.DisplayWidthPx == nil || *tl.DisplayWidthPx != w { + t.Errorf("%s: DisplayWidthPx not preserved, got %v", label, tl.DisplayWidthPx) + } + if tl.DisplayHeightPx == nil || *tl.DisplayHeightPx != h { + t.Errorf("%s: DisplayHeightPx not preserved, got %v", label, tl.DisplayHeightPx) + } + case "text_editor_20250728": + if tl.MaxCharacters == nil || *tl.MaxCharacters != maxChars { + t.Errorf("%s: MaxCharacters not preserved, got %v", label, tl.MaxCharacters) + } + case "mcp_toolset": + if tl.MCPServerName != "notion" { + t.Errorf("%s: MCPServerName not preserved, got %q", label, tl.MCPServerName) + } + if len(tl.Configs) != 1 { + t.Errorf("%s: Configs not preserved, got %v", label, tl.Configs) + } + } + } + + var tool schemas.ChatTool + if err := sonic.Unmarshal([]byte(tc.raw), &tool); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if string(tool.Type) == "" { + 
t.Errorf("Type should be preserved, got empty") + } + if tool.Name == "" { + t.Errorf("Name should be preserved, got empty") + } + assertVariantFields("first decode", tool) + + // Re-marshal and re-decode — all preserved fields should survive round trip. + out, err := schemas.MarshalSorted(tool) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var tool2 schemas.ChatTool + if err := sonic.Unmarshal(out, &tool2); err != nil { + t.Fatalf("second unmarshal failed: %v\njson: %s", err, string(out)) + } + if tool.Name != tool2.Name || tool.Type != tool2.Type { + t.Errorf("round-trip mismatch\n in: %s\n out: %s", tc.raw, string(out)) + } + assertVariantFields("round trip", tool2) + }) + } +} + +// TestToAnthropicChatRequest_ServerTools verifies every ChatTool server-tool +// shape converts correctly through ToAnthropicChatRequest. +func TestToAnthropicChatRequest_ServerTools(t *testing.T) { + mk := func(rawTool string) *schemas.BifrostChatRequest { + var tool schemas.ChatTool + if err := sonic.Unmarshal([]byte(rawTool), &tool); err != nil { + t.Fatalf("test setup: %v", err) + } + return &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-sonnet-4-6", + Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}}, + Params: &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}}, + } + } + + type check struct { + expectName string + expectType AnthropicToolType + expectWebSearch bool + expectWebFetch bool + expectComputer bool + expectTextEditor bool + expectMCPToolset bool + } + + cases := []struct { + name string + raw string + want check + }{ + { + name: "web_search", + raw: `{"type":"web_search_20260209","name":"web_search","max_uses":5}`, + want: check{expectName: "web_search", expectType: "web_search_20260209", expectWebSearch: true}, + }, + { + name: "web_fetch", + raw: `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":3,"use_cache":true}`, + 
want: check{expectName: "web_fetch", expectType: "web_fetch_20260309", expectWebFetch: true}, + }, + { + name: "computer_20251124", + raw: `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800}`, + want: check{expectName: "computer", expectType: "computer_20251124", expectComputer: true}, + }, + { + name: "text_editor_20250728", + raw: `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`, + want: check{expectName: "str_replace_based_edit_tool", expectType: "text_editor_20250728", expectTextEditor: true}, + }, + { + name: "bash_20250124", + raw: `{"type":"bash_20250124","name":"bash"}`, + want: check{expectName: "bash", expectType: "bash_20250124"}, + }, + { + name: "mcp_toolset", + raw: `{"type":"mcp_toolset","name":"notion","mcp_server_name":"notion"}`, + want: check{expectMCPToolset: true}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := mk(tc.raw) + out, err := ToAnthropicChatRequest(nil, req) + if err != nil { + t.Fatalf("conversion failed: %v", err) + } + if len(out.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d (raw: %s)", len(out.Tools), tc.raw) + } + at := out.Tools[0] + if tc.want.expectMCPToolset { + if at.MCPToolset == nil { + t.Errorf("expected MCPToolset to be set") + } + return + } + if at.Name != tc.want.expectName { + t.Errorf("Name: got %q want %q", at.Name, tc.want.expectName) + } + if at.Type == nil || *at.Type != tc.want.expectType { + t.Errorf("Type: got %v want %q", at.Type, tc.want.expectType) + } + if tc.want.expectWebSearch && at.AnthropicToolWebSearch == nil { + t.Errorf("expected AnthropicToolWebSearch populated") + } + if tc.want.expectWebFetch && at.AnthropicToolWebFetch == nil { + t.Errorf("expected AnthropicToolWebFetch populated") + } + if tc.want.expectComputer && at.AnthropicToolComputerUse == nil { + t.Errorf("expected AnthropicToolComputerUse populated") + } + if tc.want.expectTextEditor && 
at.AnthropicToolTextEditor == nil { + t.Errorf("expected AnthropicToolTextEditor populated") + } + }) + } +} + +// TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags verifies +// that when an Anthropic request carries an mcp_toolset tool with the four +// Anthropic-native flags (DeferLoading, AllowedCallers, InputExamples, +// EagerInputStreaming), those flags survive the inbound conversion into the +// neutral ResponsesTool on the mcp_servers merge path. Before the fix, the +// merge path only applied MCP configs (allowlist/cache-control) and dropped +// the flags because convertAnthropicToolToBifrost skips mcp_toolset entries. +func TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags(t *testing.T) { + // AnthropicTool.Type is pointer-to-enum and left nil for mcp_toolset; + // the toolset's identity lives in the embedded MCPToolset struct below. + + req := &AnthropicMessageRequest{ + Model: "claude-sonnet-4-6", + Tools: []AnthropicTool{ + { + Name: "notion", + DeferLoading: schemas.Ptr(true), + AllowedCallers: []string{"direct", "agent"}, + EagerInputStreaming: schemas.Ptr(false), + InputExamples: []AnthropicToolInputExample{ + {Input: json.RawMessage(`{"q":"hello"}`), Description: schemas.Ptr("basic")}, + }, + MCPToolset: &AnthropicMCPToolsetTool{ + Type: "mcp_toolset", + MCPServerName: "notion", + DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: schemas.Ptr(true)}, + }, + }, + }, + MCPServers: []AnthropicMCPServerV2{ + {Type: "url", URL: "https://mcp.example.com", Name: "notion"}, + }, + } + + got := req.ToBifrostResponsesRequest(nil) + if got == nil || got.Params == nil { + t.Fatalf("ToBifrostResponsesRequest returned nil params") + } + + // The mcp_toolset tool should have been dropped by convertAnthropicToolToBifrost + // and re-created on the mcp_servers merge path — end result: exactly one tool, + // of type mcp, carrying the Anthropic flags we set. 
+ if len(got.Params.Tools) != 1 { + t.Fatalf("expected 1 mcp tool after merge, got %d", len(got.Params.Tools)) + } + mcp := got.Params.Tools[0] + if mcp.Type != schemas.ResponsesToolTypeMCP { + t.Errorf("expected MCP tool, got type=%q", mcp.Type) + } + if mcp.DeferLoading == nil || !*mcp.DeferLoading { + t.Errorf("DeferLoading dropped on mcp_toolset merge path") + } + if len(mcp.AllowedCallers) != 2 || mcp.AllowedCallers[0] != "direct" { + t.Errorf("AllowedCallers dropped on mcp_toolset merge path, got %v", mcp.AllowedCallers) + } + if len(mcp.InputExamples) != 1 { + t.Errorf("InputExamples dropped on mcp_toolset merge path, got len=%d", len(mcp.InputExamples)) + } + if mcp.EagerInputStreaming == nil || *mcp.EagerInputStreaming { + t.Errorf("EagerInputStreaming dropped on mcp_toolset merge path, got %v", mcp.EagerInputStreaming) + } +} + +// TestToAnthropicChatRequest_ServerTools_ReproUserBug is the exact shape +// from the reported curl — web_search_20260209 with max_uses + allowed_callers. +// Verifies the request reaches ToAnthropicChatRequest output with a populated +// tools array (previously it was silently dropped). +func TestToAnthropicChatRequest_ServerTools_ReproUserBug(t *testing.T) { + raw := []byte(`{ + "model":"claude-sonnet-4-6", + "messages":[{"role":"user","content":"What is the weather in SF?"}], + "tools":[{"name":"web_search","type":"web_search_20260209","max_uses":5,"allowed_callers":["direct"]}] + }`) + // Unmarshal through the neutral schema the way the OpenAI endpoint does. + var inner struct { + Model string `json:"model"` + Messages []json.RawMessage `json:"messages"` + Tools []schemas.ChatTool `json:"tools"` + } + if err := sonic.Unmarshal(raw, &inner); err != nil { + t.Fatalf("outer unmarshal: %v", err) + } + if len(inner.Tools) != 1 { + t.Fatalf("setup: expected 1 tool in raw JSON, got %d", len(inner.Tools)) + } + if inner.Tools[0].Name == "" { + t.Errorf("Name lost at neutral-schema decode (was the bug). 
Got: %+v", inner.Tools[0]) + } + if inner.Tools[0].MaxUses == nil { + t.Errorf("MaxUses lost at neutral-schema decode (was the bug)") + } + + req := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: inner.Model, + Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}}, + Params: &schemas.ChatParameters{Tools: inner.Tools}, + } + out, err := ToAnthropicChatRequest(nil, req) + if err != nil { + t.Fatalf("conversion failed: %v", err) + } + if len(out.Tools) != 1 { + t.Fatalf("repro bug: expected 1 tool after conversion, got %d (tools array was empty — this was the bug)", len(out.Tools)) + } + if out.Tools[0].Name != "web_search" { + t.Errorf("tool Name: got %q, want %q", out.Tools[0].Name, "web_search") + } + if out.Tools[0].AnthropicToolWebSearch == nil || + out.Tools[0].AnthropicToolWebSearch.MaxUses == nil || + *out.Tools[0].AnthropicToolWebSearch.MaxUses != 5 { + t.Errorf("tool max_uses lost: %+v", out.Tools[0]) + } +} diff --git a/core/providers/anthropic/chat_test.go b/core/providers/anthropic/chat_test.go index 4d0ea9ac45..b73002009b 100644 --- a/core/providers/anthropic/chat_test.go +++ b/core/providers/anthropic/chat_test.go @@ -85,7 +85,7 @@ func TestToAnthropicChatRequest_CachingDeterminism(t *testing.T) { Model: "claude-sonnet-4-20250514", Input: []schemas.ChatMessage{{ Role: schemas.ChatMessageRoleUser, - Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")}, + Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")}, }}, Params: &schemas.ChatParameters{ Tools: []schemas.ChatTool{{ @@ -337,6 +337,115 @@ func TestToAnthropicChatRequest_ToolInputKeyOrderPreservation(t *testing.T) { } } +func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) { + thinkingText := "Let me reason step by step about this problem." + textBlock1 := "The answer is 42." + textBlock2 := "Here is why that is the case." 
+ signature := "sig_abc123" + + response := &AnthropicMessageResponse{ + ID: "msg_test123", + Type: "message", + Role: "assistant", + Model: "claude-opus-4-6-20250514", + Content: []AnthropicContentBlock{ + { + Type: AnthropicContentBlockTypeThinking, + Thinking: &thinkingText, + Signature: &signature, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock1, + }, + { + Type: AnthropicContentBlockTypeText, + Text: &textBlock2, + }, + }, + StopReason: "end_turn", + Usage: &AnthropicUsage{ + InputTokens: 100, + OutputTokens: 50, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result := response.ToBifrostChatResponse(ctx) + + if result == nil { + t.Fatal("expected non-nil result") + } + + // Content should be a combined string, not blocks + choice := result.Choices[0] + msg := choice.ChatNonStreamResponseChoice.Message + if msg.Content.ContentBlocks != nil { + t.Error("expected ContentBlocks to be nil (combined into string)") + } + if msg.Content.ContentStr == nil { + t.Fatal("expected ContentStr to be non-nil") + } + + // Combined string: thinking first, then text blocks + expected := thinkingText + "\n\n" + textBlock1 + "\n\n" + textBlock2 + if *msg.Content.ContentStr != expected { + t.Errorf("expected combined content:\n%s\ngot:\n%s", expected, *msg.Content.ContentStr) + } + + // Reasoning field should still have thinking text + if msg.ChatAssistantMessage == nil { + t.Fatal("expected ChatAssistantMessage to be non-nil") + } + if msg.ChatAssistantMessage.Reasoning == nil { + t.Fatal("expected Reasoning to be non-nil") + } + + // ReasoningDetails should have: signature-only thinking entry + content blocks boundary + rd := msg.ChatAssistantMessage.ReasoningDetails + if len(rd) < 2 { + t.Fatalf("expected at least 2 reasoning details entries, got %d", len(rd)) + } + + // First entry: thinking with signature, no text (text was cleared) + if rd[0].Type != schemas.BifrostReasoningDetailsTypeText { + t.Errorf("expected 
first reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type) + } + if rd[0].Signature == nil || *rd[0].Signature != signature { + t.Error("expected signature to be preserved") + } + if rd[0].Text != nil { + t.Error("expected thinking text to be nil (cleared to avoid duplication)") + } + + // Last entry: content blocks boundary + lastRD := rd[len(rd)-1] + if lastRD.Type != schemas.BifrostReasoningDetailsTypeContentBlocks { + t.Errorf("expected last reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeContentBlocks, lastRD.Type) + } + if lastRD.Text == nil { + t.Fatal("expected content blocks metadata to be non-nil") + } + + // var meta []contentBlockMeta + // if err := json.Unmarshal([]byte(*lastRD.Text), &meta); err != nil { + // t.Fatalf("failed to unmarshal block metadata: %v", err) + // } + // if len(meta) != 3 { + // t.Fatalf("expected 3 block metadata entries, got %d", len(meta)) + // } + // if meta[0].T != "thinking" || meta[0].L != len(thinkingText) { + // t.Errorf("block 0: expected thinking/%d, got %s/%d", len(thinkingText), meta[0].T, meta[0].L) + // } + // if meta[1].T != "text" || meta[1].L != len(textBlock1) { + // t.Errorf("block 1: expected text/%d, got %s/%d", len(textBlock1), meta[1].T, meta[1].L) + // } + // if meta[2].T != "text" || meta[2].L != len(textBlock2) { + // t.Errorf("block 2: expected text/%d, got %s/%d", len(textBlock2), meta[2].T, meta[2].L) + // } +} + func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) { // Verify existing behavior: single text block without thinking collapses to string text := "Simple response" @@ -511,3 +620,163 @@ func TestToAnthropicChatRequest_NormalFlowUnchanged(t *testing.T) { t.Errorf("block 1: expected text %q, got %v", responseText, blocks[1].Text) } } + +func TestToAnthropicChatRequest_Opus47_StripsTemperatureTopPTopK(t *testing.T) { + temp := 0.7 + topP := 0.9 + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: 
schemas.Anthropic, + Model: "claude-opus-4-7-20260401", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}, + }, + Params: &schemas.ChatParameters{ + Temperature: &temp, + TopP: &topP, + ExtraParams: map[string]interface{}{"top_k": 40}, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result, err := ToAnthropicChatRequest(ctx, bifrostReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Temperature != nil { + t.Errorf("expected Temperature to be nil for Opus 4.7, got %v", result.Temperature) + } + if result.TopP != nil { + t.Errorf("expected TopP to be nil for Opus 4.7, got %v", result.TopP) + } + if result.TopK != nil { + t.Errorf("expected TopK to be nil for Opus 4.7, got %v", result.TopK) + } +} + +func TestToAnthropicChatRequest_NonOpus47_PreservesTemperature(t *testing.T) { + temp := 0.7 + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-opus-4-6-20250514", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}, + }, + Params: &schemas.ChatParameters{ + Temperature: &temp, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result, err := ToAnthropicChatRequest(ctx, bifrostReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Temperature == nil || *result.Temperature != temp { + t.Errorf("expected Temperature %v, got %v", temp, result.Temperature) + } +} + +func TestToAnthropicChatRequest_Opus47_ReasoningMaxTokens_AdaptiveOnly(t *testing.T) { + maxTok := 2048 + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-opus-4-7-20260401", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}}, + }, + 
Params: &schemas.ChatParameters{ + MaxCompletionTokens: schemas.Ptr(8192), + Reasoning: &schemas.ChatReasoning{MaxTokens: &maxTok}, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result, err := ToAnthropicChatRequest(ctx, bifrostReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Thinking == nil { + t.Fatal("expected Thinking to be set") + } + if result.Thinking.Type != "adaptive" { + t.Errorf("expected thinking type 'adaptive' for Opus 4.7, got %q", result.Thinking.Type) + } + if result.Thinking.BudgetTokens != nil { + t.Errorf("expected BudgetTokens to be nil for Opus 4.7, got %v", result.Thinking.BudgetTokens) + } +} + +func TestToAnthropicChatRequest_NonOpus47_ReasoningMaxTokens_EnabledWithBudget(t *testing.T) { + maxTok := 2048 + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-opus-4-6-20250514", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}}, + }, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: schemas.Ptr(8192), + Reasoning: &schemas.ChatReasoning{MaxTokens: &maxTok}, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result, err := ToAnthropicChatRequest(ctx, bifrostReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Thinking == nil { + t.Fatal("expected Thinking to be set") + } + if result.Thinking.Type != "enabled" { + t.Errorf("expected thinking type 'enabled' for Opus 4.6, got %q", result.Thinking.Type) + } + if result.Thinking.BudgetTokens == nil || *result.Thinking.BudgetTokens != maxTok { + t.Errorf("expected BudgetTokens %d, got %v", maxTok, result.Thinking.BudgetTokens) + } +} + +func TestToAnthropicChatRequest_Opus47_ReasoningEffort_AdaptiveWithEffort(t *testing.T) { + effort := "high" + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + 
Model: "claude-opus-4-7-20260401", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}}, + }, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: schemas.Ptr(8192), + Reasoning: &schemas.ChatReasoning{Effort: &effort}, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + result, err := ToAnthropicChatRequest(ctx, bifrostReq) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Thinking == nil { + t.Fatal("expected Thinking to be set") + } + if result.Thinking.Type != "adaptive" { + t.Errorf("expected thinking type 'adaptive' for Opus 4.7 effort-based, got %q", result.Thinking.Type) + } + if result.OutputConfig == nil || result.OutputConfig.Effort == nil { + t.Error("expected OutputConfig.Effort to be set for Opus 4.7 effort-based reasoning") + } +} diff --git a/core/providers/anthropic/errors.go b/core/providers/anthropic/errors.go index dd1dfaf698..81bbd49d0c 100644 --- a/core/providers/anthropic/errors.go +++ b/core/providers/anthropic/errors.go @@ -54,7 +54,7 @@ func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string { return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) } -func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp AnthropicError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if errorResp.Error != nil { @@ -64,10 +64,5 @@ func parseAnthropicError(resp *fasthttp.Response, meta *providerUtils.RequestMet bifrostErr.Error.Type = &errorResp.Error.Type bifrostErr.Error.Message = errorResp.Error.Message } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git 
a/core/providers/anthropic/models.go b/core/providers/anthropic/models.go index 3da2f6458b..3815a0244b 100644 --- a/core/providers/anthropic/models.go +++ b/core/providers/anthropic/models.go @@ -1,13 +1,14 @@ package anthropic import ( + "strings" "time" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -19,57 +20,51 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide HasMore: schemas.Ptr(response.HasMore), } - // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination - // If there are more results, set next_page_token to last_id so it can be used in the next request + // Map Anthropic's cursor-based pagination to Bifrost's token-based pagination. + // If there are more results, set next_page_token to last_id for the next request. 
if response.HasMore && response.LastID != nil { bifrostResponse.NextPageToken = *response.LastID } - includedModels := make(map[string]bool) - for _, model := range response.Data { - modelID := model.ID - if !unfiltered && len(allowedModels) > 0 { - allowed := false - for _, allowedModel := range allowedModels { - if schemas.SameBaseModel(model.ID, allowedModel) { - modelID = allowedModel - allowed = true - break - } - } - if !allowed { - continue - } - } - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, modelID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.DisplayName), - Created: schemas.Ptr(model.CreatedAt.Unix()), - MaxInputTokens: model.MaxInputTokens, - MaxOutputTokens: model.MaxTokens, - ProviderExtra: model.Capabilities, - }) - includedModels[modelID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) - // Backfill allowed models that were not in the response (skip blacklisted; blacklist wins over allow list) - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + resolvedKey := strings.ToLower(result.ResolvedID) + if included[resolvedKey] { continue } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + 
Name: schemas.Ptr(model.DisplayName), + Created: schemas.Ptr(model.CreatedAt.Unix()), + MaxInputTokens: model.MaxInputTokens, + MaxOutputTokens: model.MaxTokens, + ProviderExtra: model.Capabilities, } + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[resolvedKey] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/anthropic/responses.go b/core/providers/anthropic/responses.go index ddc09fd912..ad05866e69 100644 --- a/core/providers/anthropic/responses.go +++ b/core/providers/anthropic/responses.go @@ -1445,13 +1445,13 @@ func ToAnthropicResponsesStreamResponse(ctx *schemas.BifrostContext, bifrostResp if bifrostResp.Response.ID != nil { streamMessage.ID = *bifrostResp.Response.ID } - // Preserve model from Response if available, otherwise use ExtraFields - if bifrostResp.ExtraFields.ModelRequested != "" { - if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { - streamMessage.Model = bifrostResp.Response.Model - } else { - streamMessage.Model = bifrostResp.ExtraFields.ModelRequested - } + // Prefer Response.Model, then ResolvedModelUsed, then OriginalModelRequested + if bifrostResp.Response != nil && bifrostResp.Response.Model != "" { + streamMessage.Model = bifrostResp.Response.Model + } else if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + streamMessage.Model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if bifrostResp.ExtraFields.OriginalModelRequested != "" { + streamMessage.Model = bifrostResp.ExtraFields.OriginalModelRequested } streamResp.Message = streamMessage } @@ -2158,6 +2158,9 @@ func (req *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx *schemas.Bifro // GA structured outputs - OutputConfig.Format has same structure as OutputFormat params.Text = 
convertAnthropicOutputFormatToResponsesTextConfig(req.OutputConfig.Format) } + if req.OutputConfig != nil && req.OutputConfig.TaskBudget != nil { + params.ExtraParams["task_budget"] = req.OutputConfig.TaskBudget + } if req.Thinking != nil { if req.Thinking.Type == "enabled" || req.Thinking.Type == "adaptive" { var summary *string @@ -2170,10 +2173,14 @@ func (req *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx *schemas.Bifro summary = schemas.Ptr("detailed") } } + // If the request was sent with display:"omitted", map it to summary "none" + if req.Thinking.Display != nil && *req.Thinking.Display == "omitted" { + summary = schemas.Ptr("none") + } if req.OutputConfig != nil && req.OutputConfig.Effort != nil { - // Native effort present — map to Bifrost enum (e.g., "max" → "high") + // Native effort present; pass it through unchanged (no enum mapping) params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr(MapAnthropicEffortToBifrost(*req.OutputConfig.Effort)), + Effort: schemas.Ptr(*req.OutputConfig.Effort), MaxTokens: req.Thinking.BudgetTokens, Summary: summary, } @@ -2232,6 +2239,7 @@ func (req *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx *schemas.Bifro for _, tool := range req.Tools { bifrostTool := convertAnthropicToolToBifrost(&tool) if bifrostTool != nil { + applyAnthropicToolFlagsToResponsesTool(&tool, bifrostTool) bifrostTools = append(bifrostTools, *bifrostTool) } } @@ -2241,12 +2249,17 @@ func (req *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx *schemas.Bifro } if req.MCPServers != nil { - // Build a map of mcp_toolset configs from tools[] keyed by mcp_server_name - toolsetByServer := make(map[string]*AnthropicMCPToolsetTool) + // Build a map of mcp_toolset entries from tools[] keyed by mcp_server_name. 
+ // Stores the full *AnthropicTool (not just *AnthropicMCPToolsetTool) so + // top-level Anthropic tool flags (DeferLoading, AllowedCallers, + // InputExamples, EagerInputStreaming) survive the mcp_servers merge path — + // without this, mcp_toolset tools bypass applyAnthropicToolFlagsToResponsesTool + // because convertAnthropicToolToBifrost skips them. + toolsetByServer := make(map[string]*AnthropicTool) if req.Tools != nil { for i := range req.Tools { if req.Tools[i].MCPToolset != nil { - toolsetByServer[req.Tools[i].MCPToolset.MCPServerName] = req.Tools[i].MCPToolset + toolsetByServer[req.Tools[i].MCPToolset.MCPServerName] = &req.Tools[i] } } } @@ -2255,9 +2268,10 @@ func (req *AnthropicMessageRequest) ToBifrostResponsesRequest(ctx *schemas.Bifro for _, mcpServer := range req.MCPServers { bifrostMCPTool := convertAnthropicMCPServerV2ToBifrostTool(&mcpServer) if bifrostMCPTool != nil { - // Merge mcp_toolset configs (allowed tools) if present - if toolset, ok := toolsetByServer[mcpServer.Name]; ok { - applyMCPToolsetConfigToBifrostTool(bifrostMCPTool, toolset) + // Merge mcp_toolset configs (allowed tools) + Anthropic tool flags if present + if toolWithFlags, ok := toolsetByServer[mcpServer.Name]; ok { + applyMCPToolsetConfigToBifrostTool(bifrostMCPTool, toolWithFlags.MCPToolset) + applyAnthropicToolFlagsToResponsesTool(toolWithFlags, bifrostMCPTool) } bifrostMCPTools = append(bifrostMCPTools, *bifrostMCPTool) } @@ -2299,12 +2313,15 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema if bifrostReq.Params.MaxOutputTokens != nil { anthropicReq.MaxTokens = *bifrostReq.Params.MaxOutputTokens } - // Anthropic doesn't allow both temperature and top_p to be specified - // If both are present, prefer temperature (more commonly used) - if bifrostReq.Params.Temperature != nil { - anthropicReq.Temperature = bifrostReq.Params.Temperature - } else if bifrostReq.Params.TopP != nil { - anthropicReq.TopP = bifrostReq.Params.TopP + // Opus 4.7+ 
rejects temperature, top_p, and top_k with a 400 error. + if !IsOpus47(bifrostReq.Model) { + // Anthropic doesn't allow both temperature and top_p to be specified. + // If both are present, prefer temperature (more commonly used). + if bifrostReq.Params.Temperature != nil { + anthropicReq.Temperature = bifrostReq.Params.Temperature + } else if bifrostReq.Params.TopP != nil { + anthropicReq.TopP = bifrostReq.Params.TopP + } } if bifrostReq.Params.User != nil { anthropicReq.Metadata = &AnthropicMetaData{ @@ -2364,26 +2381,31 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema } if bifrostReq.Params.Reasoning != nil { if bifrostReq.Params.Reasoning.MaxTokens != nil { - budgetTokens := *bifrostReq.Params.Reasoning.MaxTokens - if *bifrostReq.Params.Reasoning.MaxTokens == -1 { - // anthropic does not support dynamic reasoning budget like gemini - // setting it to default max tokens - budgetTokens = MinimumReasoningMaxTokens - } - if budgetTokens < MinimumReasoningMaxTokens { - return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) - } - anthropicReq.Thinking = &AnthropicThinking{ - Type: "enabled", - BudgetTokens: schemas.Ptr(budgetTokens), + if IsOpus47(bifrostReq.Model) { + // Opus 4.7+: budget_tokens removed; adaptive thinking is the only thinking-on mode. 
+ anthropicReq.Thinking = &AnthropicThinking{Type: "adaptive"} + } else { + budgetTokens := *bifrostReq.Params.Reasoning.MaxTokens + if *bifrostReq.Params.Reasoning.MaxTokens == -1 { + // anthropic does not support dynamic reasoning budget like gemini + // setting it to default max tokens + budgetTokens = MinimumReasoningMaxTokens + } + if budgetTokens < MinimumReasoningMaxTokens { + return nil, fmt.Errorf("reasoning.max_tokens must be >= %d for anthropic", MinimumReasoningMaxTokens) + } + anthropicReq.Thinking = &AnthropicThinking{ + Type: "enabled", + BudgetTokens: schemas.Ptr(budgetTokens), + } } } else { if bifrostReq.Params.Reasoning.Effort != nil { if *bifrostReq.Params.Reasoning.Effort != "none" { effort := MapBifrostEffortToAnthropic(*bifrostReq.Params.Reasoning.Effort) - if SupportsAdaptiveThinking(bifrostReq.Model) { - // Opus 4.6+: adaptive thinking + native effort + if SupportsAdaptiveThinking(bifrostReq.Model) || IsOpus47(bifrostReq.Model) { + // Opus 4.6+ and Opus 4.7+: adaptive thinking + native effort anthropicReq.Thinking = &AnthropicThinking{Type: "adaptive"} setEffortOnOutputConfig(anthropicReq, effort) } else if SupportsNativeEffort(bifrostReq.Model) { @@ -2415,6 +2437,15 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema } } } + if anthropicReq.Thinking != nil && anthropicReq.Thinking.Type != "disabled" { + if bifrostReq.Params.Reasoning != nil && + bifrostReq.Params.Reasoning.Summary != nil && *bifrostReq.Params.Reasoning.Summary == "none" { + anthropicReq.Thinking.Display = schemas.Ptr("omitted") + } else { + // Default to "summarized" to preserve visible thinking output + anthropicReq.Thinking.Display = schemas.Ptr("summarized") + } + } } // Convert service tier anthropicReq.ServiceTier = bifrostReq.Params.ServiceTier @@ -2449,7 +2480,9 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]) if ok { 
delete(anthropicReq.ExtraParams, "top_k") - anthropicReq.TopK = topK + if !IsOpus47(bifrostReq.Model) { + anthropicReq.TopK = topK + } } if speed, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["speed"]); ok { delete(anthropicReq.ExtraParams, "speed") @@ -2475,6 +2508,31 @@ func ToAnthropicResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schema } } } + if tbVal, exists := bifrostReq.Params.ExtraParams["task_budget"]; exists { + // Always consume provider-specific key from passthrough extras. + delete(anthropicReq.ExtraParams, "task_budget") + var taskBudget *AnthropicTaskBudget + switch v := tbVal.(type) { + case *AnthropicTaskBudget: + taskBudget = v + case AnthropicTaskBudget: + taskBudget = &v + default: + if data, err := providerUtils.MarshalSorted(v); err == nil { + var tb AnthropicTaskBudget + if sonic.Unmarshal(data, &tb) == nil { + taskBudget = &tb + } + } + } + if taskBudget == nil { + return nil, fmt.Errorf("invalid task_budget format for anthropic") + } + if anthropicReq.OutputConfig == nil { + anthropicReq.OutputConfig = &AnthropicOutputConfig{} + } + anthropicReq.OutputConfig.TaskBudget = taskBudget + } } // Convert tools @@ -3386,7 +3444,7 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant case AnthropicContentBlockTypeImage: // Don't emit accumulated text or tool_use blocks for images - if block.Source != nil { + if block.Source != nil && block.Source.SourceObj != nil { bifrostMsg := schemas.ResponsesMessage{ Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, @@ -3402,7 +3460,7 @@ func convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant case AnthropicContentBlockTypeDocument: // Handle document blocks similar to images - if block.Source != nil { + if block.Source != nil && block.Source.SourceObj != nil { bifrostMsg := schemas.ResponsesMessage{ Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, @@ -3492,7 +3550,7 @@ func 
convertAnthropicContentBlocksToResponsesMessagesGrouped(contentBlocks []Ant }) } case AnthropicContentBlockTypeImage: - if contentBlock.Source != nil { + if contentBlock.Source != nil && contentBlock.Source.SourceObj != nil { toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) } } @@ -3705,7 +3763,7 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex bifrostMessages = append(bifrostMessages, bifrostMsg) } case AnthropicContentBlockTypeImage: - if block.Source != nil { + if block.Source != nil && block.Source.SourceObj != nil { bifrostMsg := schemas.ResponsesMessage{ Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, @@ -3719,7 +3777,7 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex bifrostMessages = append(bifrostMessages, bifrostMsg) } case AnthropicContentBlockTypeDocument: - if block.Source != nil { + if block.Source != nil && block.Source.SourceObj != nil { bifrostMsg := schemas.ResponsesMessage{ Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), Role: role, @@ -3853,7 +3911,7 @@ func convertAnthropicContentBlocksToResponsesMessages(ctx *schemas.BifrostContex }) } case AnthropicContentBlockTypeImage: - if contentBlock.Source != nil { + if contentBlock.Source != nil && contentBlock.Source.SourceObj != nil { toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) } } @@ -4800,18 +4858,80 @@ func convertBifrostToolsToAnthropic(model string, tools []schemas.ResponsesTool, mcpServers = append(mcpServers, *server) } if toolset != nil { - anthropicTools = append(anthropicTools, AnthropicTool{MCPToolset: toolset}) + mcpTool := AnthropicTool{MCPToolset: toolset} + applyResponsesToolAnthropicFlags(&mcpTool, &tool) + anthropicTools = append(anthropicTools, mcpTool) } continue } anthropicTool := convertBifrostToolToAnthropic(model, &tool, provider, hasWebSearchOrFetch) if anthropicTool != nil { + 
applyResponsesToolAnthropicFlags(anthropicTool, &tool) anthropicTools = append(anthropicTools, *anthropicTool) } } return anthropicTools, mcpServers } +// applyAnthropicToolFlagsToResponsesTool propagates the Anthropic-native tool +// flags (DeferLoading, AllowedCallers, InputExamples, EagerInputStreaming) in +// the inbound direction: from the incoming AnthropicTool onto the neutral +// ResponsesTool when the native Anthropic /v1/messages endpoint is the entry +// point. Called once per converted tool so every return path inside +// convertAnthropicToolToBifrost benefits. +func applyAnthropicToolFlagsToResponsesTool(at *AnthropicTool, rt *schemas.ResponsesTool) { + if at == nil || rt == nil { + return + } + if at.DeferLoading != nil { + rt.DeferLoading = at.DeferLoading + } + if len(at.AllowedCallers) > 0 { + rt.AllowedCallers = at.AllowedCallers + } + if len(at.InputExamples) > 0 { + rt.InputExamples = make([]schemas.ChatToolInputExample, len(at.InputExamples)) + for i, ex := range at.InputExamples { + rt.InputExamples[i] = schemas.ChatToolInputExample{ + Input: ex.Input, + Description: ex.Description, + } + } + } + if at.EagerInputStreaming != nil { + rt.EagerInputStreaming = at.EagerInputStreaming + } +} + +// applyResponsesToolAnthropicFlags propagates the Anthropic-native tool flags +// (DeferLoading, AllowedCallers, InputExamples, EagerInputStreaming) from the +// neutral ResponsesTool onto the provider-native AnthropicTool. Called once +// per converted tool so every branch in convertBifrostToolToAnthropic +// benefits without duplicating the logic on each return path. 
+func applyResponsesToolAnthropicFlags(at *AnthropicTool, rt *schemas.ResponsesTool) { + if at == nil || rt == nil { + return + } + if rt.DeferLoading != nil { + at.DeferLoading = rt.DeferLoading + } + if len(rt.AllowedCallers) > 0 { + at.AllowedCallers = rt.AllowedCallers + } + if len(rt.InputExamples) > 0 { + at.InputExamples = make([]AnthropicToolInputExample, len(rt.InputExamples)) + for i, ex := range rt.InputExamples { + at.InputExamples[i] = AnthropicToolInputExample{ + Input: ex.Input, + Description: ex.Description, + } + } + } + if rt.EagerInputStreaming != nil { + at.EagerInputStreaming = rt.EagerInputStreaming + } +} + // Helper function to convert Tool back to AnthropicTool func convertBifrostToolToAnthropic(model string, tool *schemas.ResponsesTool, provider schemas.ModelProvider, hasWebSearchOrFetch bool) *AnthropicTool { if tool == nil { @@ -5149,36 +5269,40 @@ func (block AnthropicContentBlock) toBifrostResponsesDocumentBlock() schemas.Res resultBlock.ResponsesInputMessageContentBlockFile.Filename = block.Title } - if block.Source == nil { + if block.Source == nil || block.Source.SourceObj == nil { + // File-block rendering only applies to object-form sources + // (image / document). String-form sources (search_result) are + // handled elsewhere. 
return resultBlock } + src := block.Source.SourceObj // Handle different source types - switch block.Source.Type { + switch src.Type { case "url": // URL source - if block.Source.URL != nil { - resultBlock.ResponsesInputMessageContentBlockFile.FileURL = block.Source.URL + if src.URL != nil { + resultBlock.ResponsesInputMessageContentBlockFile.FileURL = src.URL } case "base64": // Base64 encoded data - if block.Source.Data != nil { + if src.Data != nil { // Construct data URL with media type mediaType := "application/pdf" - if block.Source.MediaType != nil { - mediaType = *block.Source.MediaType + if src.MediaType != nil { + mediaType = *src.MediaType } - dataURL := *block.Source.Data + dataURL := *src.Data if !strings.HasPrefix(dataURL, "data:") { - dataURL = "data:" + mediaType + ";base64," + *block.Source.Data + dataURL = "data:" + mediaType + ";base64," + *src.Data } resultBlock.ResponsesInputMessageContentBlockFile.FileData = &dataURL } case "text": // Plain text source - if block.Source.Data != nil { + if src.Data != nil { resultBlock.ResponsesInputMessageContentBlockFile.FileType = schemas.Ptr("text/plain") - resultBlock.ResponsesInputMessageContentBlockFile.FileData = block.Source.Data + resultBlock.ResponsesInputMessageContentBlockFile.FileData = src.Data } } diff --git a/core/providers/anthropic/text.go b/core/providers/anthropic/text.go index 3228ad49f6..39a700499b 100644 --- a/core/providers/anthropic/text.go +++ b/core/providers/anthropic/text.go @@ -103,10 +103,6 @@ func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schema TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, }, Model: response.Model, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Anthropic, - }, } } diff --git a/core/providers/anthropic/types.go b/core/providers/anthropic/types.go index f803c337e5..d2a636de8d 100644 --- a/core/providers/anthropic/types.go +++ 
b/core/providers/anthropic/types.go @@ -26,6 +26,12 @@ const ( AnthropicStructuredOutputsBetaHeader = "structured-outputs-2025-11-13" // AnthropicAdvancedToolUseBetaHeader is required for defer_loading, input_examples, and allowed_callers. AnthropicAdvancedToolUseBetaHeader = "advanced-tool-use-2025-11-20" + // AnthropicToolExamplesBetaHeader is required for tool.input_examples as a + // standalone feature (Bedrock supports this narrow header without the full + // advanced-tool-use-2025-11-20 bundle). + // Source: AWS Bedrock user guide beta-header list: + // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + AnthropicToolExamplesBetaHeader = "tool-examples-2025-10-29" // AnthropicMCPClientBetaHeader is required for MCP servers (current version). AnthropicMCPClientBetaHeader = "mcp-client-2025-11-20" // AnthropicMCPClientBetaHeaderDeprecated is the previous MCP beta header (kept for fallback). @@ -48,6 +54,12 @@ const ( AnthropicFastModeBetaHeader = "fast-mode-2026-02-01" // AnthropicRedactThinkingBetaHeader is required for redacting thinking blocks in responses. AnthropicRedactThinkingBetaHeader = "redact-thinking-2026-02-12" + // AnthropicTaskBudgetsBetaHeader is required for output_config.task_budget (Opus 4.7+). + AnthropicTaskBudgetsBetaHeader = "task-budgets-2026-03-13" + // AnthropicEagerInputStreamingBetaHeader is required for eager_input_streaming + // on custom tools (streams input_json_delta before full args are determined). + // Per Table 20: GA on Anthropic/Bedrock/Vertex, Beta on Azure. + AnthropicEagerInputStreamingBetaHeader = "fine-grained-tool-streaming-2025-05-14" // AnthropicComputerUseBetaHeader is required for computer use (version-specific). // computer_20251124 (Opus 4.6, Sonnet 4.6, Opus 4.5) uses the newer beta header. @@ -59,6 +71,7 @@ const ( // Use these with strings.HasPrefix when filtering headers per provider, // so that future date bumps (e.g. 
structured-outputs-2025-12-15) are still matched. AnthropicAdvancedToolUseBetaHeaderPrefix = "advanced-tool-use-" + AnthropicToolExamplesBetaHeaderPrefix = "tool-examples-" AnthropicStructuredOutputsBetaHeaderPrefix = "structured-outputs-" AnthropicPromptCachingScopeBetaHeaderPrefix = "prompt-caching-scope-" AnthropicMCPClientBetaHeaderPrefix = "mcp-client-" @@ -67,64 +80,123 @@ const ( AnthropicContext1MBetaHeaderPrefix = "context-1m-" AnthropicFastModeBetaHeaderPrefix = "fast-mode-" AnthropicRedactThinkingBetaHeaderPrefix = "redact-thinking-" + AnthropicTaskBudgetsBetaHeaderPrefix = "task-budgets-" + AnthropicEagerInputStreamingBetaHeaderPrefix = "fine-grained-tool-streaming-" ) // ProviderFeatureSupport defines which Anthropic features a given provider supports. -// Source: https://docs.anthropic.com/en/build-with-claude/overview (March 2026) +// +// Authoritative sources (verified 2026-04-17): +// A = Anthropic feature-availability table: +// https://platform.claude.com/docs/en/build-with-claude/overview +// B-header = AWS Bedrock user guide beta-header list: +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +// B-platform = https://platform.claude.com/docs/en/build-with-claude/claude-on-amazon-bedrock +// V-platform = https://platform.claude.com/docs/en/build-with-claude/claude-on-vertex-ai +// Az-platform = https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry +// MCP-excl = MCP connector explicit Bedrock/Vertex exclusion: +// https://platform.claude.com/docs/en/agents-and-tools/mcp-connector +// Advisor-excl = Advisor tool Claude-API-only: +// https://platform.claude.com/docs/en/agents-and-tools/tool-use/advisor-tool type ProviderFeatureSupport struct { - WebSearch bool // web_search server tool - WebSearchDynamic bool // web_search_20260209 (dynamic filtering, requires code_execution) - WebFetch bool // web_fetch server tool - CodeExecution bool // code_execution server tool 
- ComputerUse bool // computer_use client tool - Bash bool // bash client tool - Memory bool // memory client tool - TextEditor bool // text_editor client tool - ToolSearch bool // tool_search server tool - MCP bool // MCP connector - AdvancedToolUse bool // advanced-tool-use (defer_loading, input_examples, allowed_callers) - StructuredOutputs bool // strict tool validation and output_format - PromptCachingScope bool // prompt caching scope - Compaction bool // server-side context compaction - ContextEditing bool // context editing (clear_tool_uses, clear_thinking) - FilesAPI bool // Files API - InterleavedThinking bool // interleaved thinking between tool calls - Skills bool // Agent Skills - Context1M bool // 1M context window beta (for Sonnet 4.5/4 only) - FastMode bool // fast mode (Opus 4.6 only, research preview) - RedactThinking bool // redact thinking blocks in responses + WebSearch bool // web_search server tool (cite: A) + WebSearchDynamic bool // web_search_20260209 dynamic filtering (cite: A) + WebFetch bool // web_fetch server tool (cite: A) + CodeExecution bool // code_execution server tool (cite: A) + ComputerUse bool // computer_use client tool (cite: A, B-header) + Bash bool // bash client tool (cite: A, B-header) + Memory bool // memory client tool — on Bedrock bundled under context-management-2025-06-27 (cite: A, B-header) + TextEditor bool // text_editor client tool (cite: A) + ToolSearch bool // tool_search server tool — tool-search-tool-2025-10-19 (cite: A, B-header) + MCP bool // MCP connector — explicit "not supported on Bedrock/Vertex" (cite: MCP-excl) + AdvancedToolUse bool // advanced-tool-use-2025-11-20 bundle: defer_loading + input_examples + allowed_callers (cite: A) + InputExamples bool // tool.input_examples standalone — tool-examples-2025-10-29. Bedrock supports this independently of the AdvancedToolUse bundle (cite: B-header). On Anthropic / Azure the bundle implicitly covers it. 
+ StructuredOutputs bool // strict tool validation / output_format (cite: A) + PromptCachingScope bool // cache_control.scope — prompt-caching-scope-2026-01-05 (cite: A) + Compaction bool // compact_20260112 (cite: A, B-header) + ContextEditing bool // clear_tool_uses / clear_thinking (cite: A, B-header) + FilesAPI bool // files-api-2025-04-14, file_id source (cite: A) + InterleavedThinking bool // interleaved thinking between tool calls (cite: A, B-header; fails on non-allowlisted models on Bedrock/Vertex) + Skills bool // Agent Skills — container.skills object (cite: A) + ContainerBasic bool // Bare string-form container id — universally supported (cite: A) + Context1M bool // 1M context window — context-1m-2025-08-07 (cite: A) + FastMode bool // Opus 4.6 research preview — fast-mode-2026-02-01 (cite: A) + RedactThinking bool // redact-thinking-2026-02-12 (cite: A) — note Bedrock has its own "thinking encryption" (different mechanism) + TaskBudgets bool // output_config.task_budget — task-budgets-2026-03-13 (cite: A) + InferenceGeo bool // inference_geo field — Claude API only; Bedrock/Vertex/Azure use their own region-routing mechanisms (cite: A) + EagerInputStreaming bool // fine-grained-tool-streaming-2025-05-14 (cite: A, B-header) + AdvisorTool bool // advisor_tool_result block — Anthropic only (cite: Advisor-excl) FileSearch bool // file_search server tool (OpenAI-only) ImageGeneration bool // image_generation server tool (OpenAI-only) } // ProviderFeatures maps each provider to its supported Anthropic features. +// +// Every cell below is sourced from the docs named in ProviderFeatureSupport. +// "Not documented" in upstream docs is treated as unsupported here; if a user +// needs a pass-through, ExtraParams still works. var ProviderFeatures = map[schemas.ModelProvider]ProviderFeatureSupport{ + // Anthropic Claude API direct (cite: A across the board). 
schemas.Anthropic: { WebSearch: true, WebSearchDynamic: true, WebFetch: true, CodeExecution: true, ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, - MCP: true, AdvancedToolUse: true, StructuredOutputs: true, PromptCachingScope: true, + MCP: true, AdvancedToolUse: true, InputExamples: true, StructuredOutputs: true, PromptCachingScope: true, Compaction: true, ContextEditing: true, FilesAPI: true, - InterleavedThinking: true, Skills: true, Context1M: true, FastMode: true, - RedactThinking: true, + InterleavedThinking: true, Skills: true, ContainerBasic: true, Context1M: true, + FastMode: true, RedactThinking: true, TaskBudgets: true, + InferenceGeo: true, EagerInputStreaming: true, AdvisorTool: true, }, + // Google Vertex AI — cite: A (overview table) and V-platform. + // Notably NOT supported: MCP (MCP-excl), Skills/container.skills, + // InferenceGeo, FastMode, TaskBudgets, AdvisorTool, StructuredOutputs, + // PromptCachingScope (400 "unexpected beta header" per LiteLLM #19984), + // FilesAPI, WebFetch, CodeExecution, AdvancedToolUse, RedactThinking. schemas.Vertex: { - WebSearch: true, // only web_search_20250305 (basic), NOT dynamic filtering + WebSearch: true, // web search GA on Vertex per A; earlier code restricted to web_search_20250305 — A doesn't qualify ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, - Compaction: true, ContextEditing: true, - InterleavedThinking: true, Context1M: true, + ContainerBasic: true, + Compaction: true, + ContextEditing: true, + InterleavedThinking: true, // V-platform confirms; fails on non-allowlisted 4-series + Context1M: true, + EagerInputStreaming: true, // fine-grained-tool-streaming GA per A }, + // AWS Bedrock — cite: A + B-header (definitive beta-header list). 
+ // Notably NOT supported per docs: MCP, Skills, FilesAPI, WebFetch, + // WebSearch, CodeExecution, FastMode, TaskBudgets, AdvisorTool, + // InferenceGeo, RedactThinking, AdvancedToolUse (full), PromptCachingScope. schemas.Bedrock: { ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, - StructuredOutputs: true, Compaction: true, ContextEditing: true, - InterleavedThinking: true, Context1M: true, + ContainerBasic: true, + // StructuredOutputs: kept true to match pre-existing behavior and the + // provider_feature_support_test.go assertion, but NEITHER B-header + // NOR B-platform upstream docs document strict tool validation / + // output_format on Bedrock. Needs live verification. If Bedrock's + // Converse API actually rejects `strict: true`, flip this to false + // and update the corresponding test assertion. + StructuredOutputs: true, + Compaction: true, // compact-2026-01-12 per B-header + ContextEditing: true, // context-management-2025-06-27 per B-header (bundles memory) + InterleavedThinking: true, // per B-header; model-allowlisted + Context1M: true, // Opus 4.6 / Sonnet 4.6 per A + EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 per B-header + InputExamples: true, // tool-examples-2025-10-29 per B-header (standalone; Bedrock doesn't accept the full advanced-tool-use-2025-11-20 bundle — see TestFilterBetaHeadersForProvider) + // AdvancedToolUse intentionally OFF on Bedrock. The bundle header + // (advanced-tool-use-2025-11-20) is not listed in B-header; only the + // narrow tool-examples-2025-10-29 header is, gated via InputExamples above. }, + // Microsoft Azure AI Foundry — cite: A (most features azureAiBeta) + + // Az-platform ("supports most of Claude's features"). Excluded per + // Az-platform: Admin API, Models API, Message Batch API (not in scope). 
schemas.Azure: { WebSearch: true, WebSearchDynamic: true, WebFetch: true, CodeExecution: true, ComputerUse: true, Bash: true, Memory: true, TextEditor: true, ToolSearch: true, - MCP: true, AdvancedToolUse: true, StructuredOutputs: true, PromptCachingScope: true, + MCP: true, AdvancedToolUse: true, InputExamples: true, StructuredOutputs: true, PromptCachingScope: true, Compaction: true, ContextEditing: true, FilesAPI: true, - InterleavedThinking: true, Skills: true, Context1M: true, - RedactThinking: true, + InterleavedThinking: true, Skills: true, ContainerBasic: true, Context1M: true, + RedactThinking: true, TaskBudgets: true, + EagerInputStreaming: true, + // FastMode, InferenceGeo, AdvisorTool — not in Az-platform; leave off. }, } @@ -156,11 +228,88 @@ func (req *AnthropicTextRequest) IsStreamingRequested() bool { return req.Stream != nil && *req.Stream } -// AnthropicOutputConfig represents the GA structured outputs config (output_config.format) -// and the effort parameter (output_config.effort) for controlling token spending. +// AnthropicTaskBudget represents an advisory token budget for a full agentic loop (output_config.task_budget). +// The model sees a running countdown and uses it to prioritize work and finish gracefully. +// Requires beta header "task-budgets-2026-03-13". Minimum total: 20 000 tokens. +// This is advisory, not a hard cap — use max_tokens as the per-request hard ceiling. +type AnthropicTaskBudget struct { + Type string `json:"type"` // always "tokens" + Total int `json:"total"` // total advisory token budget across the agentic loop + Remaining *int `json:"remaining,omitempty"` // optional; tracks remaining tokens for client-side compaction +} + +// AnthropicOutputConfig represents the GA structured outputs config (output_config.format), +// the effort parameter (output_config.effort), and the task budget (output_config.task_budget). 
type AnthropicOutputConfig struct { - Format json.RawMessage `json:"format,omitempty"` - Effort *string `json:"effort,omitempty"` // "low", "medium", "high", "max" (Opus 4.5+) + Format json.RawMessage `json:"format,omitempty"` // JSON schema for structured outputs + Effort *string `json:"effort,omitempty"` // "low" | "medium" | "high" | "xhigh" | "max" + TaskBudget *AnthropicTaskBudget `json:"task_budget,omitempty"` // advisory token budget; requires task-budgets-2026-03-13 beta header +} + +// AnthropicContainerSkill represents a single skill attached to a container. +// Requires beta header "skills-2025-10-02". +type AnthropicContainerSkill struct { + SkillID string `json:"skill_id"` // Unique identifier for the skill + Type string `json:"type"` // "anthropic" (built-in) | "custom" (user-defined) + Version *string `json:"version,omitempty"` // Optional version pin +} + +// AnthropicContainerObject represents the object form of the container field: +// { id?: string, skills?: [...] }. The skills[] array is gated by the +// skills-2025-10-02 beta header; a bare id-only container is GA. +type AnthropicContainerObject struct { + ID *string `json:"id,omitempty"` + Skills []AnthropicContainerSkill `json:"skills,omitempty"` +} + +// AnthropicContainer is the "container" field on AnthropicMessageRequest. +// Per Anthropic docs it can be either a bare string (container id) or an +// object with id+skills[]. The object-with-skills form requires beta header +// "skills-2025-10-02"; the string form is GA. +// Source: https://platform.claude.com/docs/en/api/messages/create +type AnthropicContainer struct { + ContainerStr *string + ContainerObject *AnthropicContainerObject +} + +// MarshalJSON encodes the union as either a raw string or the object form. 
+func (c AnthropicContainer) MarshalJSON() ([]byte, error) { + if c.ContainerStr != nil && c.ContainerObject != nil { + return nil, fmt.Errorf("both ContainerStr and ContainerObject are set; only one should be non-nil") + } + if c.ContainerStr != nil { + return providerUtils.MarshalSorted(*c.ContainerStr) + } + if c.ContainerObject != nil { + return providerUtils.MarshalSorted(c.ContainerObject) + } + return providerUtils.MarshalSorted(nil) +} + +// UnmarshalJSON decodes either a string or the object form into the union. +// Clears the inactive arm on each success so a reused struct never ends up +// with both fields populated (which MarshalJSON rejects). Explicitly handles +// JSON null. Matches the ChatContainer / ChatToolChoice union patterns. +func (c *AnthropicContainer) UnmarshalJSON(data []byte) error { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + c.ContainerStr = nil + c.ContainerObject = nil + return nil + } + var s string + if err := sonic.Unmarshal(data, &s); err == nil { + c.ContainerStr = &s + c.ContainerObject = nil + return nil + } + var obj AnthropicContainerObject + if err := sonic.Unmarshal(data, &obj); err == nil { + c.ContainerStr = nil + c.ContainerObject = &obj + return nil + } + return fmt.Errorf("container field is neither a string nor a container object") } // AnthropicMessageRequest represents an Anthropic messages API request @@ -186,6 +335,7 @@ type AnthropicMessageRequest struct { ServiceTier *string `json:"service_tier,omitempty"` // "auto" or "standard_only" InferenceGeo *string `json:"inference_geo,omitempty"` // the geographic region for inference processing. If not specified, the workspace's default_inference_geo is used. 
ContextManagement *ContextManagement `json:"context_management,omitempty"` + Container *AnthropicContainer `json:"container,omitempty"` // string id OR object with skills[]; skills require skills-2025-10-02 beta // Extra params for advanced use cases ExtraParams map[string]interface{} `json:"-"` @@ -212,8 +362,9 @@ type AnthropicMetaData struct { } type AnthropicThinking struct { - Type string `json:"type"` // "enabled" or "disabled" - BudgetTokens *int `json:"budget_tokens,omitempty"` + Type string `json:"type"` // "enabled", "disabled", or "adaptive" + BudgetTokens *int `json:"budget_tokens,omitempty"` // Only for type "enabled" (not supported on Opus 4.7+) + Display *string `json:"display,omitempty"` // "summarized" | "omitted" — controls whether thinking content appears in the response (Opus 4.7+) } type ContextManagementEditType string @@ -461,6 +612,7 @@ var anthropicMessageRequestKnownFields = map[string]bool{ "service_tier": true, "inference_geo": true, "context_management": true, + "container": true, "extra_params": true, "fallbacks": true, } @@ -685,54 +837,205 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { type AnthropicContentBlockType string const ( - AnthropicContentBlockTypeText AnthropicContentBlockType = "text" - AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" - AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" - AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" - AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" - AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" - AnthropicContentBlockTypeWebSearchToolResult AnthropicContentBlockType = "web_search_tool_result" - AnthropicContentBlockTypeWebSearchToolResultError AnthropicContentBlockType = "web_search_tool_result_error" - AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" - AnthropicContentBlockTypeWebFetchToolResult 
AnthropicContentBlockType = "web_fetch_tool_result" - AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" - AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" - AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" - AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" - AnthropicContentBlockTypeCompaction AnthropicContentBlockType = "compaction" + AnthropicContentBlockTypeText AnthropicContentBlockType = "text" + AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" + AnthropicContentBlockTypeDocument AnthropicContentBlockType = "document" + AnthropicContentBlockTypeSearchResult AnthropicContentBlockType = "search_result" + AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" + AnthropicContentBlockTypeServerToolUse AnthropicContentBlockType = "server_tool_use" + AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" + AnthropicContentBlockTypeWebSearchToolResult AnthropicContentBlockType = "web_search_tool_result" + AnthropicContentBlockTypeWebSearchToolResultError AnthropicContentBlockType = "web_search_tool_result_error" + AnthropicContentBlockTypeWebSearchResult AnthropicContentBlockType = "web_search_result" + AnthropicContentBlockTypeWebFetchToolResult AnthropicContentBlockType = "web_fetch_tool_result" + AnthropicContentBlockTypeCodeExecutionToolResult AnthropicContentBlockType = "code_execution_tool_result" + AnthropicContentBlockTypeBashCodeExecutionToolResult AnthropicContentBlockType = "bash_code_execution_tool_result" + AnthropicContentBlockTypeTextEditorCodeExecutionToolResult AnthropicContentBlockType = "text_editor_code_execution_tool_result" + AnthropicContentBlockTypeToolSearchToolResult AnthropicContentBlockType = "tool_search_tool_result" + AnthropicContentBlockTypeToolReference AnthropicContentBlockType = "tool_reference" + AnthropicContentBlockTypeContainerUpload 
AnthropicContentBlockType = "container_upload" + AnthropicContentBlockTypeAdvisorToolResult AnthropicContentBlockType = "advisor_tool_result" + AnthropicContentBlockTypeMCPToolUse AnthropicContentBlockType = "mcp_tool_use" + AnthropicContentBlockTypeMCPToolResult AnthropicContentBlockType = "mcp_tool_result" + AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" + AnthropicContentBlockTypeRedactedThinking AnthropicContentBlockType = "redacted_thinking" + AnthropicContentBlockTypeCompaction AnthropicContentBlockType = "compaction" ) -// AnthropicContentBlock represents content in Anthropic message format +// AnthropicToolCallerType identifies which agentic caller produced a tool +// invocation. Appears on tool_use, server_tool_use, and every *_tool_result +// block per Anthropic docs. +// Source: https://platform.claude.com/docs/en/api/beta/messages/create +type AnthropicToolCallerType string + +const ( + AnthropicToolCallerTypeDirect AnthropicToolCallerType = "direct" + AnthropicToolCallerTypeCodeExecution20250825 AnthropicToolCallerType = "code_execution_20250825" + AnthropicToolCallerTypeCodeExecution20260120 AnthropicToolCallerType = "code_execution_20260120" +) + +// AnthropicToolCaller represents the "caller" union on tool-use and +// tool-result blocks. For the two code-execution variants, ToolID is required +// and identifies the upstream server tool that invoked the nested tool. +type AnthropicToolCaller struct { + Type AnthropicToolCallerType `json:"type"` + ToolID *string `json:"tool_id,omitempty"` // Required for code_execution_* caller types +} + +// AnthropicContentBlock represents content in Anthropic message format. +// This is a fat struct: every optional field here is used by at least one +// block type. Consult Anthropic's content-block docs before adding a field +// so we reuse existing ones where semantics align. 
type AnthropicContentBlock struct { - Type AnthropicContentBlockType `json:"type"` // "text", "image", "document", "tool_use", "tool_result", "thinking" - Text *string `json:"text,omitempty"` // For text content - Thinking *string `json:"thinking,omitempty"` // For thinking content - Signature *string `json:"signature,omitempty"` // For signature content - Data *string `json:"data,omitempty"` // For data content (encrypted data for redacted thinking, signature does not come with this) - ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content - ID *string `json:"id,omitempty"` // For tool_use content - Name *string `json:"name,omitempty"` // For tool_use content - Input json.RawMessage `json:"input,omitempty"` // For tool_use content (json.RawMessage preserves key ordering for prompt caching) - ServerName *string `json:"server_name,omitempty"` // For mcp_tool_use content - Content *AnthropicContent `json:"content,omitempty"` // For tool_result content - IsError *bool `json:"is_error,omitempty"` // For tool_result content, indicates error state - Source *AnthropicSource `json:"source,omitempty"` // For image/document content - CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` // For cache control content - Citations *AnthropicCitations `json:"citations,omitempty"` // For document content - Context *string `json:"context,omitempty"` // For document content - Title *string `json:"title,omitempty"` // For document content - URL *string `json:"url,omitempty"` // For web_search_result content - EncryptedContent *string `json:"encrypted_content,omitempty"` // For web_search_result content - PageAge *string `json:"page_age,omitempty"` // For web_search_result content - ErrorCode *string `json:"error_code,omitempty"` // For web_search_tool_result_error content -} - -// AnthropicSource represents image or document source in Anthropic format + Type AnthropicContentBlockType `json:"type"` // Discriminator + Text *string 
`json:"text,omitempty"` // text block; also "advisor_result" variant + Thinking *string `json:"thinking,omitempty"` // thinking block + Signature *string `json:"signature,omitempty"` // thinking block signature + Data *string `json:"data,omitempty"` // redacted_thinking encrypted data (no signature) + ToolUseID *string `json:"tool_use_id,omitempty"` // tool_result, *_tool_result blocks + ID *string `json:"id,omitempty"` // tool_use, server_tool_use, mcp_tool_use + Name *string `json:"name,omitempty"` // tool_use, server_tool_use; also reused for tool_reference's tool_name via ToolName + Input json.RawMessage `json:"input,omitempty"` // tool_use / server_tool_use (json.RawMessage preserves key ordering for prompt caching) + ServerName *string `json:"server_name,omitempty"` // mcp_tool_use + Content *AnthropicContent `json:"content,omitempty"` // tool_result, *_tool_result; inner structured content or string + IsError *bool `json:"is_error,omitempty"` // tool_result, *_tool_result + Source *AnthropicBlockSource `json:"source,omitempty"` // image, document (SourceObj) or search_result (SourceStr) — union type + CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` // any block + Citations *AnthropicCitations `json:"citations,omitempty"` // text, document, search_result (request config) or response citations array + Context *string `json:"context,omitempty"` // document + Title *string `json:"title,omitempty"` // document, search_result, web_search_result + URL *string `json:"url,omitempty"` // web_search_result, web_fetch_result + EncryptedContent *string `json:"encrypted_content,omitempty"` // web_search_result, advisor_redacted_result, compaction + PageAge *string `json:"page_age,omitempty"` // web_search_result + ErrorCode *string `json:"error_code,omitempty"` // any *_tool_result_error variant + Caller *AnthropicToolCaller `json:"caller,omitempty"` // tool_use, server_tool_use, every *_tool_result block + + // search_result block: the API uses the 
literal key "source" with a plain + // string value on search_result blocks but an OBJECT on image/document + // blocks. Both shapes are carried under the single "source" key by the + // AnthropicBlockSource union above (SourceStr for the string form, + // SourceObj for the object form), so no ExtraParams pass-through is + // needed for search_result sources. + + // code_execution_tool_result / bash_code_execution_tool_result result-variant fields + Stdout *string `json:"stdout,omitempty"` + Stderr *string `json:"stderr,omitempty"` + ReturnCode *int `json:"return_code,omitempty"` + EncryptedStdout *string `json:"encrypted_stdout,omitempty"` + + // text_editor_code_execution_tool_result variants + FileType *string `json:"file_type,omitempty"` // view_result: "text"|"image"|"pdf" + StartLine *int `json:"start_line,omitempty"` // view_result + NumLines *int `json:"num_lines,omitempty"` // view_result + TotalLines *int `json:"total_lines,omitempty"` // view_result + IsFileUpdate *bool `json:"is_file_update,omitempty"` // create_result + OldStart *int `json:"old_start,omitempty"` // str_replace_result + OldLines *int `json:"old_lines,omitempty"` // str_replace_result + NewStart *int `json:"new_start,omitempty"` // str_replace_result + NewLines *int `json:"new_lines,omitempty"` // str_replace_result + Lines []string `json:"lines,omitempty"` // str_replace_result + ErrorMessage *string `json:"error_message,omitempty"` // text_editor error variant + + // tool_search_tool_result success variant + ToolReferences []AnthropicContentBlock `json:"tool_references,omitempty"` // tool_search_tool_search_result (array of tool_reference blocks) + + // tool_reference block — tool_name field on the block itself + ToolName *string `json:"tool_name,omitempty"` + + // container_upload block +
web_fetch_result inner file_id reference + FileID *string `json:"file_id,omitempty"` + + // web_fetch_tool_result / web_fetch_result inner retrieval timestamp + RetrievedAt *string `json:"retrieved_at,omitempty"` +} + +// AnthropicSource represents image or document source in Anthropic format. +// +// Per docs (https://platform.claude.com/docs/en/api/messages/create) the +// documented type values and their carrying fields are: +// - "base64" → MediaType + Data +// - "url" → URL +// - "text" → MediaType ("text/plain") + Data +// - "content_block" → Content (nested string OR array of inner blocks); +// recursive ContentBlockSource used inside DocumentBlockParam +// - "file" → FileID (requires files-api-2025-04-14 beta) +// +// The struct is a superset — only the fields relevant to Type should be set +// at a time. type AnthropicSource struct { - Type string `json:"type"` // "base64", "url", "text", "content_block" - MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", "application/pdf", etc. - Data *string `json:"data,omitempty"` // Base64-encoded data (for base64 type) - URL *string `json:"url,omitempty"` // URL (for url type) + Type string `json:"type"` // "base64" | "url" | "text" | "content" | "content_block" (alias) | "file" + MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", "application/pdf", etc. + Data *string `json:"data,omitempty"` // Base64-encoded data (base64 type) or text payload (text type) + URL *string `json:"url,omitempty"` // URL (url type) + FileID *string `json:"file_id,omitempty"` // File ID (file type; requires files-api-2025-04-14 beta) + Content json.RawMessage `json:"content,omitempty"` // For content_block type: nested content — string OR array of inner blocks (TextBlockParam / ImageBlockParam). json.RawMessage preserves exact bytes for prompt caching. +} + +// AnthropicBlockSource is the union "source" field on a content block. 
+// +// Anthropic's API uses the literal JSON key "source" for two incompatible +// shapes depending on which block the key appears on: +// +// - On `image` / `document` blocks: an OBJECT describing the source +// (type + media_type + data/url/file_id). Modeled by AnthropicSource. +// - On `search_result` blocks: a plain STRING identifier (URL/path). +// +// This union wrapper lets AnthropicContentBlock carry either shape under +// the single "source" JSON key. +// +// Docs: +// - https://platform.claude.com/docs/en/api/messages/create (ImageBlockParam, DocumentBlockParam) +// - https://platform.claude.com/docs/en/api/beta/messages/create (SearchResultBlockParam) +type AnthropicBlockSource struct { + SourceStr *string // search_result: plain string (URL, path, identifier) + SourceObj *AnthropicSource // image / document: object form +} + +// MarshalJSON emits either the string or the object form directly (unwrapped). +// Matches the union-type idiom used by AnthropicCitations, AnthropicContainer, +// and CompactManagementEditTypeAndValue. +func (s AnthropicBlockSource) MarshalJSON() ([]byte, error) { + if s.SourceStr != nil && s.SourceObj != nil { + return nil, fmt.Errorf("both SourceStr and SourceObj are set; only one should be non-nil") + } + if s.SourceStr != nil { + return providerUtils.MarshalSorted(*s.SourceStr) + } + if s.SourceObj != nil { + return providerUtils.MarshalSorted(s.SourceObj) + } + return providerUtils.MarshalSorted(nil) +} + +// UnmarshalJSON decodes either the string form (search_result blocks) or the +// object form (image/document blocks) into the union: sonic-decode into each +// variant, first success wins (matching AnthropicCitations.UnmarshalJSON). +// Clears the inactive arm on each success so a reused struct never ends up +// with both fields populated (which MarshalJSON rejects). Explicitly handles +// JSON null.
+func (s *AnthropicBlockSource) UnmarshalJSON(data []byte) error { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + s.SourceStr = nil + s.SourceObj = nil + return nil + } + var str string + if err := sonic.Unmarshal(data, &str); err == nil { + s.SourceStr = &str + s.SourceObj = nil + return nil + } + var obj AnthropicSource + if err := sonic.Unmarshal(data, &obj); err == nil { + s.SourceStr = nil + s.SourceObj = &obj + return nil + } + return fmt.Errorf("source field is neither a string nor an AnthropicSource object") } type AnthropicCitationType string @@ -796,7 +1099,7 @@ func (ac *AnthropicCitations) MarshalJSON() ([]byte, error) { ac.TextCitations = nil } if ac.Config != nil && ac.TextCitations != nil { - return nil, fmt.Errorf("AnthropicCitations: both Config and TextCitations are set; only one should be non-nil") + return nil, fmt.Errorf("both Config and TextCitations are set; only one should be non-nil") } if ac.Config != nil { @@ -840,7 +1143,9 @@ type AnthropicToolType string const ( AnthropicToolTypeCustom AnthropicToolType = "custom" + AnthropicToolTypeBash20241022 AnthropicToolType = "bash_20241022" // computer-use-2024-10-22 beta AnthropicToolTypeBash20250124 AnthropicToolType = "bash_20250124" + AnthropicToolTypeComputer20241022 AnthropicToolType = "computer_20241022" // computer-use-2024-10-22 beta AnthropicToolTypeComputer20250124 AnthropicToolType = "computer_20250124" AnthropicToolTypeComputer20251124 AnthropicToolType = "computer_20251124" // for claude-opus-4.5, claude-opus-4.6, claude-sonnet-4.6 AnthropicToolTypeTextEditor20250124 AnthropicToolType = "text_editor_20250124" @@ -908,10 +1213,19 @@ type AnthropicToolWebSearch struct { } type AnthropicToolWebFetch struct { - MaxUses *int `json:"max_uses,omitempty"` - AllowedDomains []string `json:"allowed_domains,omitempty"` - BlockedDomains []string `json:"blocked_domains,omitempty"` - MaxContentTokens *int `json:"max_content_tokens,omitempty"` 
+ MaxUses *int `json:"max_uses,omitempty"` + AllowedDomains []string `json:"allowed_domains,omitempty"` + BlockedDomains []string `json:"blocked_domains,omitempty"` + MaxContentTokens *int `json:"max_content_tokens,omitempty"` + Citations *AnthropicCitations `json:"citations,omitempty"` // {enabled: bool} — toggles citation emission on fetched documents + UseCache *bool `json:"use_cache,omitempty"` // web_fetch_20260309+ only — enables server-side page cache +} + +// AnthropicToolTextEditor holds fields specific to the text_editor tool +// variants. Only text_editor_20250728 (and later) honours max_characters +// as a view-truncation cap. +type AnthropicToolTextEditor struct { + MaxCharacters *int `json:"max_characters,omitempty"` // text_editor_20250728+ only } // AnthropicToolInputExample represents an input example for a tool (beta feature) @@ -922,19 +1236,21 @@ type AnthropicToolInputExample struct { // AnthropicTool represents a tool in Anthropic format type AnthropicTool struct { - Name string `json:"name"` - Type *AnthropicToolType `json:"type,omitempty"` - Description *string `json:"description,omitempty"` - InputSchema *schemas.ToolFunctionParameters `json:"input_schema,omitempty"` - CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` - DeferLoading *bool `json:"defer_loading,omitempty"` // Beta: defer loading of tool definition - Strict *bool `json:"strict,omitempty"` // Whether to enforce strict parameter validation - AllowedCallers []string `json:"allowed_callers,omitempty"` // Beta: which callers can use this tool - InputExamples []AnthropicToolInputExample `json:"input_examples,omitempty"` // Beta: example inputs for the tool + Name string `json:"name"` + Type *AnthropicToolType `json:"type,omitempty"` + Description *string `json:"description,omitempty"` + InputSchema *schemas.ToolFunctionParameters `json:"input_schema,omitempty"` + CacheControl *schemas.CacheControl `json:"cache_control,omitempty"` + DeferLoading *bool 
`json:"defer_loading,omitempty"` // Beta: defer loading of tool definition + Strict *bool `json:"strict,omitempty"` // Whether to enforce strict parameter validation + AllowedCallers []string `json:"allowed_callers,omitempty"` // Beta: which callers can use this tool + InputExamples []AnthropicToolInputExample `json:"input_examples,omitempty"` // Beta: example inputs for the tool + EagerInputStreaming *bool `json:"eager_input_streaming,omitempty"` // Custom tools only; beta fine-grained-tool-streaming-2025-05-14 *AnthropicToolComputerUse *AnthropicToolWebSearch *AnthropicToolWebFetch + *AnthropicToolTextEditor // MCP toolset (mcp-client-2025-11-20 format) — embedded when Type is nil and MCPToolset is set MCPToolset *AnthropicMCPToolsetTool `json:"-"` // Serialized via custom MarshalJSON @@ -1248,7 +1564,7 @@ type AnthropicFileDeleteResponse struct { } // ToBifrostFileUploadResponse converts an Anthropic file response to Bifrost file upload response. -func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Type, @@ -1259,9 +1575,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1277,7 +1591,7 @@ func (r *AnthropicFileResponse) ToBifrostFileUploadResponse(providerName schemas } // 
ToBifrostFileRetrieveResponse converts an Anthropic file response to Bifrost file retrieve response. -func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { +func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse { resp := &schemas.BifrostFileRetrieveResponse{ ID: r.ID, Object: r.Type, @@ -1288,9 +1602,7 @@ func (r *AnthropicFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: schemas.FileStatusProcessed, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -1320,4 +1632,4 @@ func parseAnthropicFileTimestamp(timestamp string) int64 { // AnthropicCountTokensResponse models the payload returned by Anthropic's count tokens endpoint. type AnthropicCountTokensResponse struct { InputTokens int `json:"input_tokens"` -} +} \ No newline at end of file diff --git a/core/providers/anthropic/utils.go b/core/providers/anthropic/utils.go index 1dbbbedb08..db0ccada96 100644 --- a/core/providers/anthropic/utils.go +++ b/core/providers/anthropic/utils.go @@ -14,6 +14,77 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +// anthropicToolTypePrefixToFeature maps Anthropic server-tool type prefixes +// to the corresponding ProviderFeatureSupport flag. Mirrors the structure of +// betaHeaderPrefixToFeature (defined later in this file) so tool-type gating +// and beta-header gating share the same shape. +// +// Prefix-based so future version bumps (e.g. web_search_20261231) flow +// through without a code change. 
Exact-match types (currently just +// "mcp_toolset") are handled separately. +var anthropicToolTypePrefixToFeature = map[string]func(ProviderFeatureSupport) bool{ + "web_search_": func(f ProviderFeatureSupport) bool { return f.WebSearch }, + "web_fetch_": func(f ProviderFeatureSupport) bool { return f.WebFetch }, + "code_execution_": func(f ProviderFeatureSupport) bool { return f.CodeExecution }, + "computer_": func(f ProviderFeatureSupport) bool { return f.ComputerUse }, + "bash_": func(f ProviderFeatureSupport) bool { return f.Bash }, + "memory_": func(f ProviderFeatureSupport) bool { return f.Memory }, + "text_editor_": func(f ProviderFeatureSupport) bool { return f.TextEditor }, + "tool_search_tool_": func(f ProviderFeatureSupport) bool { return f.ToolSearch }, +} + +// isAnthropicServerToolSupported returns whether the given Anthropic server-tool +// type string is supported by the provider's ProviderFeatureSupport. Unknown +// types return true (forward-compat: let the provider reject if truly invalid +// rather than Bifrost dropping a tool Anthropic has just added). +func isAnthropicServerToolSupported(toolType string, features ProviderFeatureSupport) bool { + // Exact-match types first. + if toolType == "mcp_toolset" { + return features.MCP + } + // Prefix match for versioned types. + for prefix, check := range anthropicToolTypePrefixToFeature { + if strings.HasPrefix(toolType, prefix) { + return check(features) + } + } + return true +} + +// ValidateChatToolsForProvider is the chat-path mirror of +// ValidateToolsForProvider. It partitions []schemas.ChatTool into a keep-set +// (function/custom tools + server tools supported on the target provider) +// and a dropped-set (server-tool Type strings the provider doesn't support +// per ProviderFeatures). +// +// Does NOT mutate its input. Callers decide the policy (silent strip vs +// fail-fast). 
The Bedrock ChatCompletion path uses silent strip so the +// request still reaches the provider without the unsupported tool; the model +// responds with a prose completion instead of tool use. +// +// Unknown providers keep all tools (safe default for custom providers), +// matching ValidateToolsForProvider. +func ValidateChatToolsForProvider(tools []schemas.ChatTool, provider schemas.ModelProvider) (keep []schemas.ChatTool, dropped []string) { + features, ok := ProviderFeatures[provider] + if !ok { + return tools, nil + } + for _, tool := range tools { + // Function/custom tools are universal — always keep. + if tool.Function != nil || tool.Custom != nil { + keep = append(keep, tool) + continue + } + t := string(tool.Type) + if isAnthropicServerToolSupported(t, features) { + keep = append(keep, tool) + } else { + dropped = append(dropped, t) + } + } + return keep, dropped +} + // ValidateToolsForProvider checks if all tools in the request are supported by the given provider. // Returns an error for the first unsupported tool found. func ValidateToolsForProvider(tools []schemas.ResponsesTool, provider schemas.ModelProvider) error { @@ -90,6 +161,448 @@ var ( } ) +// stripUnsupportedAnthropicFields removes request-level and tool-level fields +// that the target Anthropic-family provider does not support, according to the +// ProviderFeatures map (types.go). Tool-type validation (fail-closed) is handled +// separately by ValidateToolsForProvider; this helper handles request-level +// fields (strip silently, since they're additive enhancements). +// +// Mutates req in place. Safe to call multiple times. +func stripUnsupportedAnthropicFields(req *AnthropicMessageRequest, provider schemas.ModelProvider, model string) { + if req == nil { + return + } + features, ok := ProviderFeatures[provider] + if !ok { + // Unknown provider — safe default: don't strip anything. + return + } + + // Request-level fields gated by ProviderFeatures flags. 
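The prefix-gated keep/dropped partition described for ValidateChatToolsForProvider can be sketched standalone. The types and names below are simplified stand-ins for illustration, not the real Bifrost schemas:

```go
package main

import (
	"fmt"
	"strings"
)

// FeatureSupport is an illustrative stand-in for the provider feature flags.
type FeatureSupport struct {
	WebSearch bool
	WebFetch  bool
}

// Versioned server-tool types are matched by prefix so future date bumps
// (e.g. a hypothetical web_search_20261231) pass through without a code change.
var prefixToCheck = map[string]func(FeatureSupport) bool{
	"web_search_": func(f FeatureSupport) bool { return f.WebSearch },
	"web_fetch_":  func(f FeatureSupport) bool { return f.WebFetch },
}

// supported returns true for unknown types (forward compatibility: let the
// provider reject a truly invalid type rather than dropping it here).
func supported(toolType string, f FeatureSupport) bool {
	for prefix, check := range prefixToCheck {
		if strings.HasPrefix(toolType, prefix) {
			return check(f)
		}
	}
	return true
}

// partition mirrors the keep/dropped split; callers choose the policy for
// dropped types (silent strip vs fail-fast).
func partition(toolTypes []string, f FeatureSupport) (keep, dropped []string) {
	for _, t := range toolTypes {
		if supported(t, f) {
			keep = append(keep, t)
		} else {
			dropped = append(dropped, t)
		}
	}
	return keep, dropped
}

func main() {
	f := FeatureSupport{WebSearch: true, WebFetch: false}
	keep, dropped := partition([]string{"web_search_20250305", "web_fetch_20250910", "mystery_tool"}, f)
	fmt.Println(keep)    // [web_search_20250305 mystery_tool]
	fmt.Println(dropped) // [web_fetch_20250910]
}
```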
+ if req.Container != nil { + // Skills form (object with skills[]) is beta-gated; bare string id is universal. + // Intent signal: non-empty skills = caller explicitly wants skills; empty + // skills:[] = likely caller oversight we can silently correct. + hasSkills := req.Container.ContainerObject != nil && len(req.Container.ContainerObject.Skills) > 0 + // Strip an explicit empty or non-empty skills array on Skills=false + // providers. omitempty already handles this at serialize time for empty + // arrays, but we clear it explicitly so hasSkills-based decisions below + // and raw-path parity both stay correct. + if !features.Skills && req.Container.ContainerObject != nil && req.Container.ContainerObject.Skills != nil { + req.Container.ContainerObject.Skills = nil + } + switch { + case hasSkills && !features.Skills: + // Caller wanted non-empty skills but provider doesn't support them. + req.Container = nil + case !hasSkills && !features.ContainerBasic: + req.Container = nil + } + } + if len(req.MCPServers) > 0 && !features.MCP { + req.MCPServers = nil + } + // Speed is both provider-gated (FastMode flag) and model-gated + // (Opus 4.6 only per SupportsFastMode). Strip if either gate fails — + // Anthropic's API rejects speed:"fast" on non-Opus-4.6 models with a 400. + if req.Speed != nil && (!features.FastMode || !SupportsFastMode(model)) { + req.Speed = nil + } + if req.OutputConfig != nil && req.OutputConfig.TaskBudget != nil && !features.TaskBudgets { + req.OutputConfig.TaskBudget = nil + // Clean up an empty OutputConfig so it doesn't serialize as {} + if req.OutputConfig.Format == nil && req.OutputConfig.Effort == nil { + req.OutputConfig = nil + } + } + if req.InferenceGeo != nil && !features.InferenceGeo { + req.InferenceGeo = nil + } + // cache_control.scope — strip on providers without PromptCachingScope + // support at every slot scope can live: top-level request, tools, system + // blocks, and message content blocks. 
Vertex additionally uses the + // marshal-time SetStripCacheControlScope mechanism (vertex/utils.go:104, + // types.go MarshalJSON); after this strip runs, that marshal-time pass + // becomes a safe no-op for Vertex (nothing left to strip). + if !features.PromptCachingScope { + // Top-level. + if req.CacheControl != nil && req.CacheControl.Scope != nil { + req.CacheControl.Scope = nil + // If scope was the only meaningful field, drop the whole CacheControl + // so we don't serialize an empty object. + if req.CacheControl.TTL == nil && req.CacheControl.Type == "" { + req.CacheControl = nil + } + } + // Per-tool cache_control.scope. + for i := range req.Tools { + if req.Tools[i].CacheControl != nil && req.Tools[i].CacheControl.Scope != nil { + req.Tools[i].CacheControl.Scope = nil + // Drop the parent if scope was the only meaningful field. + if req.Tools[i].CacheControl.TTL == nil && req.Tools[i].CacheControl.Type == "" { + req.Tools[i].CacheControl = nil + } + } + } + // System block scopes. + if req.System != nil { + for i := range req.System.ContentBlocks { + if req.System.ContentBlocks[i].CacheControl != nil && req.System.ContentBlocks[i].CacheControl.Scope != nil { + req.System.ContentBlocks[i].CacheControl.Scope = nil + if req.System.ContentBlocks[i].CacheControl.TTL == nil && req.System.ContentBlocks[i].CacheControl.Type == "" { + req.System.ContentBlocks[i].CacheControl = nil + } + } + } + } + // Message block scopes. + for mi := range req.Messages { + for ci := range req.Messages[mi].Content.ContentBlocks { + cc := req.Messages[mi].Content.ContentBlocks[ci].CacheControl + if cc != nil && cc.Scope != nil { + cc.Scope = nil + if cc.TTL == nil && cc.Type == "" { + req.Messages[mi].Content.ContentBlocks[ci].CacheControl = nil + } + } + } + } + } + if req.ContextManagement != nil { + // Gate edits by their type — compaction vs context-editing flags. 
+ kept := make([]ContextManagementEdit, 0, len(req.ContextManagement.Edits)) + for _, edit := range req.ContextManagement.Edits { + switch edit.Type { + case ContextManagementEditTypeCompact: + if features.Compaction { + kept = append(kept, edit) + } + case ContextManagementEditTypeClearToolUses, ContextManagementEditTypeClearThinking: + if features.ContextEditing { + kept = append(kept, edit) + } + default: + // Unknown edit type — keep and let upstream reject. + kept = append(kept, edit) + } + } + if len(kept) == 0 { + req.ContextManagement = nil + } else { + req.ContextManagement.Edits = kept + } + } + + // Tool-level flags — strip per-tool without dropping the tool itself. + for i := range req.Tools { + tool := &req.Tools[i] + if tool.DeferLoading != nil && !features.AdvancedToolUse { + tool.DeferLoading = nil + } + if len(tool.AllowedCallers) > 0 && !features.AdvancedToolUse { + tool.AllowedCallers = nil + } + // InputExamples has its own feature flag (InputExamples) because + // Bedrock supports the tool-examples-2025-10-29 header standalone — + // without the full advanced-tool-use-2025-11-20 bundle. On Anthropic + // and Azure, the bundle flag (AdvancedToolUse) is also set, so either + // gate would work there. + if len(tool.InputExamples) > 0 && !features.InputExamples { + tool.InputExamples = nil + } + if tool.EagerInputStreaming != nil && !features.EagerInputStreaming { + tool.EagerInputStreaming = nil + } + if tool.Strict != nil && !features.StructuredOutputs { + tool.Strict = nil + } + } +} + +// stripUnsupportedFieldsFromRawBody is the raw-JSON equivalent of +// StripUnsupportedAnthropicFields. It mutates the request body bytes using +// sjson/gjson (preserving key order for prompt caching) so the raw-body +// passthrough path has behavioural parity with the typed conversion path. +// +// Scope: every field the typed helper handles. 
+// - top-level: speed (provider + model gated), container (.skills gated by +// features.Skills, bare string by features.ContainerBasic), mcp_servers, +// inference_geo, cache_control.scope, output_config.task_budget, +// context_management.edits[] (gated per edit type). +// - nested: tool.CacheControl.Scope, system block scopes, message block +// scopes (all stripped when !features.PromptCachingScope). +// - per-tool: defer_loading, allowed_callers (AdvancedToolUse bundle), +// input_examples (narrow InputExamples flag), eager_input_streaming +// (EagerInputStreaming), strict (StructuredOutputs). +// +// Unknown providers: safe default — no stripping (parity with the typed helper). +// Unknown edit types in context_management: left in place for the provider +// to reject (parity with the typed helper). +func stripUnsupportedFieldsFromRawBody(jsonBody []byte, provider schemas.ModelProvider, model string) ([]byte, error) { + if len(jsonBody) == 0 { + return jsonBody, nil + } + features, ok := ProviderFeatures[provider] + if !ok { + return jsonBody, nil + } + + // Fall back to body-embedded model when caller didn't pass one. 
+ if model == "" { + if modelResult := providerUtils.GetJSONField(jsonBody, "model"); modelResult.Exists() { + model = modelResult.String() + } + } + + var err error + + // speed — provider AND model gate + if providerUtils.JSONFieldExists(jsonBody, "speed") { + if !features.FastMode || !SupportsFastMode(model) { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "speed") + if err != nil { + return nil, fmt.Errorf("strip raw speed: %w", err) + } + } + } + + // inference_geo + if !features.InferenceGeo && providerUtils.JSONFieldExists(jsonBody, "inference_geo") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "inference_geo") + if err != nil { + return nil, fmt.Errorf("strip raw inference_geo: %w", err) + } + } + + // mcp_servers + if !features.MCP && providerUtils.JSONFieldExists(jsonBody, "mcp_servers") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "mcp_servers") + if err != nil { + return nil, fmt.Errorf("strip raw mcp_servers: %w", err) + } + } + + // container — two variants: bare string id (ContainerBasic), or object + // {id, skills[]} where skills require Skills flag. + // Distinguishes three states: no skills field (bare form), skills:[] (empty + // array — caller oversight, silently strip), skills:[…] (non-empty — caller + // explicitly wants skills). Mirrors the typed path's hybrid decision. + if containerResult := providerUtils.GetJSONField(jsonBody, "container"); containerResult.Exists() { + hasSkillsField, hasNonEmptySkills := false, false + if containerResult.IsObject() { + if skills := containerResult.Get("skills"); skills.Exists() { + hasSkillsField = true + if skills.IsArray() && len(skills.Array()) > 0 { + hasNonEmptySkills = true + } + } + } + // Always strip the skills key on Skills=false providers — critical on + // the raw path since bytes flow directly to the provider and an + // explicit empty array would still be rejected as unknown field. 
+ if !features.Skills && hasSkillsField { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "container.skills") + if err != nil { + return nil, fmt.Errorf("strip raw container.skills: %w", err) + } + } + drop := false + switch { + case hasNonEmptySkills: + drop = !features.Skills + default: + drop = !features.ContainerBasic + } + if drop { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "container") + if err != nil { + return nil, fmt.Errorf("strip raw container: %w", err) + } + } + } + + // output_config.task_budget + if !features.TaskBudgets && providerUtils.JSONFieldExists(jsonBody, "output_config.task_budget") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "output_config.task_budget") + if err != nil { + return nil, fmt.Errorf("strip raw output_config.task_budget: %w", err) + } + // Drop an empty parent so we don't serialize output_config:{} (matches + // typed-path behavior at lines 129-134). + if oc := providerUtils.GetJSONField(jsonBody, "output_config"); oc.IsObject() && len(oc.Map()) == 0 { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "output_config") + if err != nil { + return nil, fmt.Errorf("strip raw output_config: %w", err) + } + } + } + + // top-level cache_control.scope + if !features.PromptCachingScope && providerUtils.JSONFieldExists(jsonBody, "cache_control.scope") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "cache_control.scope") + if err != nil { + return nil, fmt.Errorf("strip raw cache_control.scope: %w", err) + } + // Drop an empty parent so we don't serialize cache_control:{} (matches + // typed-path behavior at lines 147-153). + if cc := providerUtils.GetJSONField(jsonBody, "cache_control"); cc.IsObject() && len(cc.Map()) == 0 { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "cache_control") + if err != nil { + return nil, fmt.Errorf("strip raw cache_control: %w", err) + } + } + } + + // context_management.edits[] — gate per edit.type. 
+ if editsResult := providerUtils.GetJSONField(jsonBody, "context_management.edits"); editsResult.Exists() && editsResult.IsArray() { + edits := editsResult.Array() + // Collect indices to drop (iterate forwards, delete in reverse). + dropIndices := []int{} + for i, edit := range edits { + editType := edit.Get("type").String() + keep := true + switch editType { + case string(ContextManagementEditTypeCompact): + keep = features.Compaction + case string(ContextManagementEditTypeClearToolUses), string(ContextManagementEditTypeClearThinking): + keep = features.ContextEditing + } + if !keep { + dropIndices = append(dropIndices, i) + } + } + if len(dropIndices) == len(edits) && len(edits) > 0 { + // All edits unsupported — drop the whole context_management. + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "context_management") + if err != nil { + return nil, fmt.Errorf("strip raw context_management: %w", err) + } + } else { + for i := len(dropIndices) - 1; i >= 0; i-- { + path := fmt.Sprintf("context_management.edits.%d", dropIndices[i]) + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, path) + if err != nil { + return nil, fmt.Errorf("strip raw context_management.edits[%d]: %w", dropIndices[i], err) + } + } + } + } + + // per-tool flags + nested scope + if toolsResult := providerUtils.GetJSONField(jsonBody, "tools"); toolsResult.Exists() && toolsResult.IsArray() { + for i := range toolsResult.Array() { + base := fmt.Sprintf("tools.%d", i) + if !features.AdvancedToolUse { + if providerUtils.JSONFieldExists(jsonBody, base+".defer_loading") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".defer_loading") + if err != nil { + return nil, fmt.Errorf("strip raw %s.defer_loading: %w", base, err) + } + } + if providerUtils.JSONFieldExists(jsonBody, base+".allowed_callers") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".allowed_callers") + if err != nil { + return nil, fmt.Errorf("strip raw %s.allowed_callers: %w", base, err) 
+ } + } + } + if !features.InputExamples && providerUtils.JSONFieldExists(jsonBody, base+".input_examples") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".input_examples") + if err != nil { + return nil, fmt.Errorf("strip raw %s.input_examples: %w", base, err) + } + } + if !features.EagerInputStreaming && providerUtils.JSONFieldExists(jsonBody, base+".eager_input_streaming") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".eager_input_streaming") + if err != nil { + return nil, fmt.Errorf("strip raw %s.eager_input_streaming: %w", base, err) + } + } + if !features.StructuredOutputs && providerUtils.JSONFieldExists(jsonBody, base+".strict") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".strict") + if err != nil { + return nil, fmt.Errorf("strip raw %s.strict: %w", base, err) + } + } + if !features.PromptCachingScope && providerUtils.JSONFieldExists(jsonBody, base+".cache_control.scope") { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".cache_control.scope") + if err != nil { + return nil, fmt.Errorf("strip raw %s.cache_control.scope: %w", base, err) + } + // Drop the parent if cache_control is now an empty object, so + // we don't forward a malformed `cache_control: {}` marker. + if ccResult := providerUtils.GetJSONField(jsonBody, base+".cache_control"); ccResult.Exists() && ccResult.IsObject() && len(ccResult.Map()) == 0 { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, base+".cache_control") + if err != nil { + return nil, fmt.Errorf("strip raw %s.cache_control empty parent: %w", base, err) + } + } + } + } + } + + // Nested scope on system blocks (system can be a string OR array of blocks). 
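The strip-and-prune pattern used throughout the raw-body sanitizer (delete a nested field, then drop the parent if it is now an empty object) can be illustrated with a stdlib-only sketch. Note the real helpers go through gjson/sjson precisely because they preserve key order, which matters for prompt caching; a map[string]any round-trip, as here, does not:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// deletePath removes a dotted path (e.g. "cache_control.scope") from a parsed
// JSON object, then prunes parents that became empty objects so a malformed
// marker like cache_control: {} is never forwarded. Absent paths are a no-op.
func deletePath(obj map[string]any, path string) {
	parts := strings.Split(path, ".")
	cur := obj
	parents := []map[string]any{}
	for _, p := range parts[:len(parts)-1] {
		next, ok := cur[p].(map[string]any)
		if !ok {
			return // path absent: nothing to strip
		}
		parents = append(parents, cur)
		cur = next
	}
	delete(cur, parts[len(parts)-1])
	// Prune now-empty parents from the innermost outward.
	for i := len(parents) - 1; i >= 0; i-- {
		if len(cur) == 0 {
			delete(parents[i], parts[i])
		}
		cur = parents[i]
	}
}

func main() {
	raw := []byte(`{"model":"claude","cache_control":{"scope":"request"}}`)
	var obj map[string]any
	if err := json.Unmarshal(raw, &obj); err != nil {
		panic(err)
	}
	deletePath(obj, "cache_control.scope")
	out, _ := json.Marshal(obj)
	fmt.Println(string(out)) // {"model":"claude"} — empty parent pruned too
}
```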
+ if !features.PromptCachingScope { + if systemResult := providerUtils.GetJSONField(jsonBody, "system"); systemResult.Exists() && systemResult.IsArray() { + for i := range systemResult.Array() { + path := fmt.Sprintf("system.%d.cache_control.scope", i) + if providerUtils.JSONFieldExists(jsonBody, path) { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, path) + if err != nil { + return nil, fmt.Errorf("strip raw system[%d].cache_control.scope: %w", i, err) + } + parentPath := fmt.Sprintf("system.%d.cache_control", i) + if ccResult := providerUtils.GetJSONField(jsonBody, parentPath); ccResult.Exists() && ccResult.IsObject() && len(ccResult.Map()) == 0 { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, parentPath) + if err != nil { + return nil, fmt.Errorf("strip raw system[%d].cache_control empty parent: %w", i, err) + } + } + } + } + } + // Nested scope on messages[].content[] blocks. + if messagesResult := providerUtils.GetJSONField(jsonBody, "messages"); messagesResult.Exists() && messagesResult.IsArray() { + messages := messagesResult.Array() + for mi := range messages { + contentResult := providerUtils.GetJSONField(jsonBody, fmt.Sprintf("messages.%d.content", mi)) + if !contentResult.Exists() || !contentResult.IsArray() { + continue + } + for ci := range contentResult.Array() { + path := fmt.Sprintf("messages.%d.content.%d.cache_control.scope", mi, ci) + if providerUtils.JSONFieldExists(jsonBody, path) { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, path) + if err != nil { + return nil, fmt.Errorf("strip raw messages[%d].content[%d].cache_control.scope: %w", mi, ci, err) + } + parentPath := fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci) + if ccResult := providerUtils.GetJSONField(jsonBody, parentPath); ccResult.Exists() && ccResult.IsObject() && len(ccResult.Map()) == 0 { + jsonBody, err = providerUtils.DeleteJSONField(jsonBody, parentPath) + if err != nil { + return nil, fmt.Errorf("strip raw 
messages[%d].content[%d].cache_control empty parent: %w", mi, ci, err) + } + } + } + } + } + } + } + + return jsonBody, nil +} + +// IsOpus47 returns true if the model is Claude Opus 4.7 or a later generation where: +// - Extended thinking (budget_tokens) is removed — only adaptive thinking is supported. +// - temperature, top_p, and top_k are not supported (setting them returns a 400). +func IsOpus47(model string) bool { + model = strings.ToLower(model) + if !strings.Contains(model, "opus") { + return false + } + return strings.Contains(model, "4-7") || strings.Contains(model, "4.7") +} + // SupportsNativeEffort returns true if the model supports Anthropic's native output_config.effort parameter. // Currently supported on Claude Opus 4.5 and Opus 4.6. func SupportsNativeEffort(model string) bool { @@ -101,12 +614,33 @@ func SupportsNativeEffort(model string) bool { strings.Contains(model, "4-6") || strings.Contains(model, "4.6") } +// SupportsFastMode returns true if the model supports speed:"fast" (research +// preview). Per Anthropic's fast-mode docs, only Opus 4.6 supports it; +// requests carrying speed:"fast" to any other model are rejected with 400. +// Beta header: fast-mode-2026-02-01. +// +// Source: https://platform.claude.com/docs/en/build-with-claude/fast-mode +func SupportsFastMode(model string) bool { + model = strings.ToLower(model) + if !strings.Contains(model, "opus") { + return false + } + return strings.Contains(model, "4-6") || strings.Contains(model, "4.6") +} + // SupportsAdaptiveThinking returns true if the model supports thinking.type: "adaptive". -// Currently only supported on Claude Opus 4.6. +// Currently supported on Claude Opus 4.6, Claude Sonnet 4.6, and Claude Opus 4.7+. +// On Opus 4.7+ adaptive is the only thinking-on mode; on Opus 4.6 and Sonnet 4.6 it +// coexists with the deprecated budget_tokens-based extended thinking. 
func SupportsAdaptiveThinking(model string) bool { + if IsOpus47(model) { + return true + } model = strings.ToLower(model) - return strings.Contains(model, "opus") && - (strings.Contains(model, "4-6") || strings.Contains(model, "4.6")) + if !strings.Contains(model, "4-6") && !strings.Contains(model, "4.6") { + return false + } + return strings.Contains(model, "opus") || strings.Contains(model, "sonnet") } // MapBifrostEffortToAnthropic maps a Bifrost effort level to an Anthropic effort level. @@ -118,15 +652,6 @@ func MapBifrostEffortToAnthropic(effort string) string { return effort } -// MapAnthropicEffortToBifrost maps an Anthropic effort level to a Bifrost effort level. -// Anthropic supports "max" (Opus 4.6+) which is not in Bifrost's enum; it maps to "high". -func MapAnthropicEffortToBifrost(effort string) string { - if effort == "max" { - return "high" - } - return effort -} - // setEffortOnOutputConfig merges the effort value into the request's OutputConfig, // preserving any existing Format field (used for structured outputs). func setEffortOnOutputConfig(req *AnthropicMessageRequest, effort string) { @@ -136,7 +661,7 @@ func setEffortOnOutputConfig(req *AnthropicMessageRequest, effort string) { req.OutputConfig.Effort = &effort } -func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, providerName schemas.ModelProvider, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { +func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, isStreaming bool, excludeFields []string) ([]byte, *schemas.BifrostError) { // Large payload mode: body streams directly from the LP reader in completeRequest/ // setAnthropicRequestBody — skip all body building here (matches CheckContextAndGetRequestBody). 
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) { @@ -156,7 +681,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi _, model := schemas.ParseModelString(modelStr, schemas.Anthropic) jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", model) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -168,36 +693,56 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi } jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", defaultMaxTokens) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Add stream if streaming if isStreaming { jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Strip auto-injectable server-side tools to prevent conflicts with API auto-injection jsonBody, err = StripAutoInjectableTools(jsonBody) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) + } + // Sanitize raw-body fields the target provider does not support. + // Behavioural parity with StripUnsupportedAnthropicFields on the typed path. + // Feature gating keyed to schemas.Anthropic (not providerName) to match + // the typed path below which also hardcodes schemas.Anthropic — ensures + // custom Anthropic aliases get identical feature lookup in both modes. 
+ jsonBody, err = stripUnsupportedFieldsFromRawBody(jsonBody, schemas.Anthropic, "") + if err != nil { + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) + } + // Auto-inject matching anthropic-beta headers for fields the sanitizer + // preserved (speed, task_budget, cache_control.scope, input_examples, + // defer_loading, allowed_callers, eager_input_streaming, mcp_servers, + // structured outputs, etc). Without this, raw-body callers who supply + // gated fields but not headers would 400 upstream. Single source of + // truth: probe-unmarshal into the typed struct and reuse the typed + // path's header walker. + var probe AnthropicMessageRequest + if err := schemas.Unmarshal(jsonBody, &probe); err == nil { + AddMissingBetaHeadersToContext(ctx, &probe, schemas.Anthropic) } // Remove excluded fields for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else { // Convert request to Anthropic format reqBody, convErr := ToAnthropicResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr) } if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil) } AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Anthropic) if isStreaming { @@ -206,7 +751,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Marshal struct to JSON bytes jsonBody, err = providerUtils.MarshalSorted(reqBody) if 
err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err)) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { @@ -215,14 +760,14 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // Use MergeExtraParamsIntoJSON which preserves key order jsonBody, err = providerUtils.MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } // Remove excluded fields after merging (using sjson to preserve order) for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } else if len(excludeFields) > 0 { @@ -230,7 +775,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi for _, field := range excludeFields { jsonBody, err = providerUtils.DeleteJSONField(jsonBody, field) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } @@ -239,7 +784,7 @@ func getRequestBodyForResponses(ctx *schemas.BifrostContext, request *schemas.Bi // delete fallbacks field jsonBody, err = 
providerUtils.DeleteJSONField(jsonBody, "fallbacks") if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } return jsonBody, nil @@ -272,15 +817,48 @@ func AddMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM headers = appendUniqueHeader(headers, AnthropicStructuredOutputsBetaHeader) } } - // Check for advanced-tool-use features + // Check for advanced-tool-use features. defer_loading and + // allowed_callers are only available as part of the bundle + // header; input_examples additionally has a standalone header + // (tool-examples-2025-10-29) used on Bedrock where the bundle is + // not accepted. if tool.DeferLoading != nil && *tool.DeferLoading { - headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + if !hasProvider || features.AdvancedToolUse { + headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + } } if len(tool.InputExamples) > 0 { - headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + if !hasProvider || features.AdvancedToolUse { + // Bundle header covers input_examples transitively. + headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + } else if features.InputExamples { + // Narrow standalone header (e.g. Bedrock). + headers = appendUniqueHeader(headers, AnthropicToolExamplesBetaHeader) + } } if len(tool.AllowedCallers) > 0 { - headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + if !hasProvider || features.AdvancedToolUse { + headers = appendUniqueHeader(headers, AnthropicAdvancedToolUseBetaHeader) + }
+ }
+ // Check for fine-grained tool streaming (eager_input_streaming). + // Beta fine-grained-tool-streaming-2025-05-14 — required for + // input_json_delta streaming on custom tools. + if tool.EagerInputStreaming != nil && *tool.EagerInputStreaming { + if !hasProvider || features.EagerInputStreaming { + headers = appendUniqueHeader(headers, AnthropicEagerInputStreamingBetaHeader) + } } // Check for cache control with scope if !hasCachingScope && tool.CacheControl != nil && tool.CacheControl.Scope != nil { @@ -291,6 +869,14 @@ func AddMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM } } } + // Check for cache control with scope at the top level of the request + // (mirrors the tool/system/message checks below). + if !hasCachingScope && req.CacheControl != nil && req.CacheControl.Scope != nil { + if !hasProvider || features.PromptCachingScope { + headers = appendUniqueHeader(headers, AnthropicPromptCachingScopeBetaHeader) + hasCachingScope = true + } + } // Check for compaction if req.ContextManagement != nil { for _, edit := range req.ContextManagement.Edits { @@ -318,12 +904,20 @@ func AddMissingBetaHeadersToContext(ctx *schemas.BifrostContext, req *AnthropicM headers = appendUniqueHeader(headers, AnthropicInterleavedThinkingBetaHeader) } } - // Check for fast mode + // Check for fast mode. Only add the beta header when both the provider + // supports fast mode AND the model does (Opus 4.6 only per + // SupportsFastMode); otherwise sending the header guarantees a 400.
if req.Speed != nil && *req.Speed == "fast" { - if !hasProvider || features.FastMode { + if (!hasProvider || features.FastMode) && SupportsFastMode(req.Model) { headers = appendUniqueHeader(headers, AnthropicFastModeBetaHeader) } } + // Check for task budget + if req.OutputConfig != nil && req.OutputConfig.TaskBudget != nil { + if !hasProvider || features.TaskBudgets { + headers = appendUniqueHeader(headers, AnthropicTaskBudgetsBetaHeader) + } + } // Check for output format (structured outputs) if req.OutputFormat != nil { if !hasProvider || features.StructuredOutputs { @@ -400,11 +994,14 @@ var betaHeaderPrefixKnown = []string{ "context-management-", "files-api-", AnthropicAdvancedToolUseBetaHeaderPrefix, + AnthropicToolExamplesBetaHeaderPrefix, AnthropicInterleavedThinkingBetaHeaderPrefix, AnthropicSkillsBetaHeaderPrefix, AnthropicContext1MBetaHeaderPrefix, AnthropicFastModeBetaHeaderPrefix, AnthropicRedactThinkingBetaHeaderPrefix, + AnthropicTaskBudgetsBetaHeaderPrefix, + AnthropicEagerInputStreamingBetaHeaderPrefix, } // betaHeaderPrefixExists checks if any header in existing shares a known prefix with newHeader. 
@@ -594,11 +1191,14 @@ var betaHeaderPrefixToFeature = map[string]func(ProviderFeatureSupport) bool{ "context-management-": func(f ProviderFeatureSupport) bool { return f.ContextEditing }, "files-api-": func(f ProviderFeatureSupport) bool { return f.FilesAPI }, AnthropicAdvancedToolUseBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.AdvancedToolUse }, + AnthropicToolExamplesBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.InputExamples }, AnthropicInterleavedThinkingBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.InterleavedThinking }, AnthropicSkillsBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.Skills }, AnthropicContext1MBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.Context1M }, AnthropicFastModeBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.FastMode }, AnthropicRedactThinkingBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.RedactThinking }, + AnthropicTaskBudgetsBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.TaskBudgets }, + AnthropicEagerInputStreamingBetaHeaderPrefix: func(f ProviderFeatureSupport) bool { return f.EagerInputStreaming }, } // MergeBetaHeaders collects anthropic-beta values from provider ExtraHeaders and @@ -1087,7 +1687,7 @@ func ConvertToAnthropicImageBlock(block schemas.ChatContentBlock) AnthropicConte imageBlock := AnthropicContentBlock{ Type: AnthropicContentBlockTypeImage, CacheControl: block.CacheControl, - Source: &AnthropicSource{}, + Source: &AnthropicBlockSource{SourceObj: &AnthropicSource{}}, } if block.ImageURLStruct == nil { @@ -1098,8 +1698,8 @@ func ConvertToAnthropicImageBlock(block schemas.ChatContentBlock) AnthropicConte sanitizedURL, err := schemas.SanitizeImageURL(block.ImageURLStruct.URL) if err != nil { // Best-effort: treat as a regular URL without sanitization - imageBlock.Source.Type = "url" - imageBlock.Source.URL = &block.ImageURLStruct.URL + imageBlock.Source.SourceObj.Type = 
"url" + imageBlock.Source.SourceObj.URL = &block.ImageURLStruct.URL return imageBlock } urlTypeInfo := schemas.ExtractURLTypeInfo(sanitizedURL) @@ -1120,18 +1720,18 @@ func ConvertToAnthropicImageBlock(block schemas.ChatContentBlock) AnthropicConte // Convert to Anthropic source format if formattedImgContent.Type == schemas.ImageContentTypeURL { - imageBlock.Source.Type = "url" - imageBlock.Source.URL = &formattedImgContent.URL + imageBlock.Source.SourceObj.Type = "url" + imageBlock.Source.SourceObj.URL = &formattedImgContent.URL } else { if formattedImgContent.MediaType != "" { - imageBlock.Source.MediaType = &formattedImgContent.MediaType + imageBlock.Source.SourceObj.MediaType = &formattedImgContent.MediaType } - imageBlock.Source.Type = "base64" + imageBlock.Source.SourceObj.Type = "base64" // Use the base64 data without the data URL prefix if urlTypeInfo.DataURLWithoutPrefix != nil { - imageBlock.Source.Data = urlTypeInfo.DataURLWithoutPrefix + imageBlock.Source.SourceObj.Data = urlTypeInfo.DataURLWithoutPrefix } else { - imageBlock.Source.Data = &formattedImgContent.URL + imageBlock.Source.SourceObj.Data = &formattedImgContent.URL } } @@ -1143,7 +1743,7 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo documentBlock := AnthropicContentBlock{ Type: AnthropicContentBlockTypeDocument, CacheControl: block.CacheControl, - Source: &AnthropicSource{}, + Source: &AnthropicBlockSource{SourceObj: &AnthropicSource{}}, } if block.Citations != nil { @@ -1163,8 +1763,8 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo // Handle file URL if file.FileURL != nil && *file.FileURL != "" { - documentBlock.Source.Type = "url" - documentBlock.Source.URL = file.FileURL + documentBlock.Source.SourceObj.Type = "url" + documentBlock.Source.SourceObj.URL = file.FileURL return documentBlock } @@ -1174,8 +1774,8 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo // Check if it's plain 
text based on file type if file.FileType != nil && (*file.FileType == "text/plain" || *file.FileType == "txt") { - documentBlock.Source.Type = "text" - documentBlock.Source.Data = &fileData + documentBlock.Source.SourceObj.Type = "text" + documentBlock.Source.SourceObj.Data = &fileData return documentBlock } @@ -1184,30 +1784,30 @@ func ConvertToAnthropicDocumentBlock(block schemas.ChatContentBlock) AnthropicCo if urlTypeInfo.DataURLWithoutPrefix != nil { // It's a data URL, extract the base64 content - documentBlock.Source.Type = "base64" - documentBlock.Source.Data = urlTypeInfo.DataURLWithoutPrefix + documentBlock.Source.SourceObj.Type = "base64" + documentBlock.Source.SourceObj.Data = urlTypeInfo.DataURLWithoutPrefix // Set media type from data URL or file type if urlTypeInfo.MediaType != nil { - documentBlock.Source.MediaType = urlTypeInfo.MediaType + documentBlock.Source.SourceObj.MediaType = urlTypeInfo.MediaType } else if file.FileType != nil { - documentBlock.Source.MediaType = file.FileType + documentBlock.Source.SourceObj.MediaType = file.FileType } return documentBlock } } // Default to base64 for binary files - documentBlock.Source.Type = "base64" - documentBlock.Source.Data = &fileData + documentBlock.Source.SourceObj.Type = "base64" + documentBlock.Source.SourceObj.Data = &fileData // Set media type if file.FileType != nil { - documentBlock.Source.MediaType = file.FileType + documentBlock.Source.SourceObj.MediaType = file.FileType } else { // Default to PDF if not specified mediaType := "application/pdf" - documentBlock.Source.MediaType = &mediaType + documentBlock.Source.SourceObj.MediaType = &mediaType } return documentBlock } @@ -1220,7 +1820,7 @@ func ConvertResponsesFileBlockToAnthropic(fileBlock *schemas.ResponsesInputMessa documentBlock := AnthropicContentBlock{ Type: AnthropicContentBlockTypeDocument, CacheControl: cacheControl, - Source: &AnthropicSource{}, + Source: &AnthropicBlockSource{SourceObj: &AnthropicSource{}}, } if citations != nil 
{ @@ -1242,9 +1842,9 @@ func ConvertResponsesFileBlockToAnthropic(fileBlock *schemas.ResponsesInputMessa // Check if it's plain text based on file type if fileBlock.FileType != nil && (*fileBlock.FileType == "text/plain" || *fileBlock.FileType == "txt") { - documentBlock.Source.Type = "text" - documentBlock.Source.Data = &fileData - documentBlock.Source.MediaType = schemas.Ptr("text/plain") + documentBlock.Source.SourceObj.Type = "text" + documentBlock.Source.SourceObj.Data = &fileData + documentBlock.Source.SourceObj.MediaType = schemas.Ptr("text/plain") return documentBlock } @@ -1254,38 +1854,38 @@ func ConvertResponsesFileBlockToAnthropic(fileBlock *schemas.ResponsesInputMessa if urlTypeInfo.DataURLWithoutPrefix != nil { // It's a data URL, extract the base64 content - documentBlock.Source.Type = "base64" - documentBlock.Source.Data = urlTypeInfo.DataURLWithoutPrefix + documentBlock.Source.SourceObj.Type = "base64" + documentBlock.Source.SourceObj.Data = urlTypeInfo.DataURLWithoutPrefix // Set media type from data URL or file type if urlTypeInfo.MediaType != nil { - documentBlock.Source.MediaType = urlTypeInfo.MediaType + documentBlock.Source.SourceObj.MediaType = urlTypeInfo.MediaType } else if fileBlock.FileType != nil { - documentBlock.Source.MediaType = fileBlock.FileType + documentBlock.Source.SourceObj.MediaType = fileBlock.FileType } return documentBlock } } // Default to base64 for binary files (raw base64 without prefix) - documentBlock.Source.Type = "base64" - documentBlock.Source.Data = &fileData + documentBlock.Source.SourceObj.Type = "base64" + documentBlock.Source.SourceObj.Data = &fileData // Set media type if fileBlock.FileType != nil { - documentBlock.Source.MediaType = fileBlock.FileType + documentBlock.Source.SourceObj.MediaType = fileBlock.FileType } else { // Default to PDF if not specified mediaType := "application/pdf" - documentBlock.Source.MediaType = &mediaType + documentBlock.Source.SourceObj.MediaType = &mediaType } return 
documentBlock } // Handle file URL if fileBlock.FileURL != nil && *fileBlock.FileURL != "" { - documentBlock.Source.Type = "url" - documentBlock.Source.URL = fileBlock.FileURL + documentBlock.Source.SourceObj.Type = "url" + documentBlock.Source.SourceObj.URL = fileBlock.FileURL return documentBlock } @@ -1302,22 +1902,24 @@ func (block AnthropicContentBlock) ToBifrostContentImageBlock() schemas.ChatCont } func getImageURLFromBlock(block AnthropicContentBlock) string { - if block.Source == nil { + // Image blocks always carry object-form sources (never string form). + if block.Source == nil || block.Source.SourceObj == nil { return "" } + src := block.Source.SourceObj // Handle base64 data - convert to data URL - if block.Source.Data != nil { + if src.Data != nil { mime := "image/png" - if block.Source.MediaType != nil && *block.Source.MediaType != "" { - mime = *block.Source.MediaType + if src.MediaType != nil && *src.MediaType != "" { + mime = *src.MediaType } - return "data:" + mime + ";base64," + *block.Source.Data + return "data:" + mime + ";base64," + *src.Data } // Handle regular URLs - if block.Source.URL != nil { - return *block.Source.URL + if src.URL != nil { + return *src.URL } return "" diff --git a/core/providers/anthropic/utils_test.go b/core/providers/anthropic/utils_test.go index e9d7172e42..0a0160316e 100644 --- a/core/providers/anthropic/utils_test.go +++ b/core/providers/anthropic/utils_test.go @@ -772,19 +772,32 @@ func TestAddMissingBetaHeadersToContext_PerProvider(t *testing.T) { }, unexpectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, - // Fast mode tests + // Fast mode tests — fast mode is Opus 4.6 only (research preview), + // so tests must set Model to exercise the path. Non-Opus-4.6 models + // are model-gated out regardless of provider flag. 
{ name: "Anthropic gets fast mode header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ + Model: "claude-opus-4-6", Speed: schemas.Ptr("fast"), }, expectHeaders: []string{AnthropicFastModeBetaHeader}, }, + { + name: "Anthropic skips fast mode header on non-Opus-4.6 model", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + Model: "claude-sonnet-4-6", + Speed: schemas.Ptr("fast"), + }, + unexpectHeaders: []string{AnthropicFastModeBetaHeader}, + }, { name: "Bedrock skips fast mode header", provider: schemas.Bedrock, req: &AnthropicMessageRequest{ + Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression Speed: schemas.Ptr("fast"), }, unexpectHeaders: []string{AnthropicFastModeBetaHeader}, @@ -793,10 +806,63 @@ func TestAddMissingBetaHeadersToContext_PerProvider(t *testing.T) { name: "Azure skips fast mode header", provider: schemas.Azure, req: &AnthropicMessageRequest{ + Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression Speed: schemas.Ptr("fast"), }, unexpectHeaders: []string{AnthropicFastModeBetaHeader}, }, + // Fine-grained tool streaming (eager_input_streaming) — per Table 20: + // GA on Anthropic / Bedrock / Vertex, Beta on Azure. All four should + // auto-inject fine-grained-tool-streaming-2025-05-14 when a tool has + // eager_input_streaming: true. 
+ { + name: "Anthropic gets eager_input_streaming header", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, + }, + expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, + { + name: "Bedrock gets eager_input_streaming header", + provider: schemas.Bedrock, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, + }, + expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, + { + name: "Vertex gets eager_input_streaming header", + provider: schemas.Vertex, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, + }, + expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, + { + name: "Azure gets eager_input_streaming header", + provider: schemas.Azure, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, + }, + expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, + { + name: "eager_input_streaming header absent when flag is false", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(false)}}, + }, + unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, + { + name: "eager_input_streaming header absent when unset", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + Tools: []AnthropicTool{{Name: "t1"}}, + }, + unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, + }, } for _, tt := range tests { @@ -998,6 +1064,7 @@ func TestFilterBetaHeadersForProvider(t *testing.T) { AnthropicContextManagementBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicContext1MBetaHeader, + AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Vertex) if len(result) != len(supported) { 
@@ -1049,6 +1116,7 @@ func TestFilterBetaHeadersForProvider(t *testing.T) { AnthropicSkillsBetaHeader, AnthropicContext1MBetaHeader, AnthropicRedactThinkingBetaHeader, + AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Azure) if len(result) != len(supported) { @@ -1064,6 +1132,7 @@ func TestFilterBetaHeadersForProvider(t *testing.T) { AnthropicContextManagementBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicContext1MBetaHeader, + AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Bedrock) if len(result) != len(supported) { @@ -1184,6 +1253,239 @@ func TestFilterBetaHeadersForProvider(t *testing.T) { } } +func TestStripUnsupportedFieldsFromRawBody(t *testing.T) { + t.Run("bedrock_strips_new_request_level_fields", func(t *testing.T) { + // Raw body with every new typed field. Targeting Bedrock: speed (no FastMode), + // inference_geo (no InferenceGeo), mcp_servers (no MCP), container.skills + // (no Skills), top-level cache_control.scope (no PromptCachingScope), + // output_config.task_budget (no TaskBudgets). All should be stripped. 
+ input := []byte(`{ + "model":"claude-opus-4-6", + "speed":"fast", + "inference_geo":"us-east-1", + "mcp_servers":[{"type":"url","url":"https://example.com","name":"x"}], + "container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]}, + "cache_control":{"type":"ephemeral","ttl":"5m","scope":"user"}, + "output_config":{"task_budget":{"type":"tokens","total":20000}} + }`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, path := range []string{"speed", "inference_geo", "mcp_servers", "container", "cache_control.scope", "output_config.task_budget"} { + if providerUtils.JSONFieldExists(result, path) { + t.Errorf("expected %q to be stripped for Bedrock, got: %s", path, string(result)) + } + } + // Confirm non-scope cache_control fields are retained. + if !providerUtils.JSONFieldExists(result, "cache_control.ttl") { + t.Errorf("expected cache_control.ttl to survive, got: %s", string(result)) + } + }) + + t.Run("vertex_strips_mcp_strict_and_input_examples_via_feature_check", func(t *testing.T) { + // Vertex: no MCP, no InputExamples, no StructuredOutputs. + // tool.strict stripped; tool.input_examples stripped; mcp_servers stripped. + // tool.cache_control.scope stripped (Vertex has no PromptCachingScope). 
+ input := []byte(`{ + "model":"claude-sonnet-4-6", + "mcp_servers":[{"type":"url","url":"u","name":"n"}], + "tools":[{"name":"t1","strict":true,"input_examples":[{"input":{"a":1}}],"cache_control":{"type":"ephemeral","scope":"user"}}] + }`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Vertex, "claude-sonnet-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, path := range []string{"mcp_servers", "tools.0.strict", "tools.0.input_examples", "tools.0.cache_control.scope"} { + if providerUtils.JSONFieldExists(result, path) { + t.Errorf("expected %q to be stripped for Vertex, got: %s", path, string(result)) + } + } + if !providerUtils.JSONFieldExists(result, "tools.0.name") { + t.Errorf("expected tool name to survive") + } + }) + + t.Run("bedrock_keeps_input_examples_via_standalone_flag", func(t *testing.T) { + // Bedrock has InputExamples=true via tool-examples-2025-10-29 but + // AdvancedToolUse=false. input_examples should be KEPT; defer_loading + // and allowed_callers (bundle-only) should be STRIPPED. + input := []byte(`{ + "model":"claude-opus-4-6", + "tools":[{"name":"t1","input_examples":[{"input":{"a":1}}],"defer_loading":true,"allowed_callers":["direct"]}] + }`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !providerUtils.JSONFieldExists(result, "tools.0.input_examples") { + t.Errorf("expected tools[0].input_examples to survive on Bedrock, got: %s", string(result)) + } + for _, path := range []string{"tools.0.defer_loading", "tools.0.allowed_callers"} { + if providerUtils.JSONFieldExists(result, path) { + t.Errorf("expected %q to be stripped for Bedrock (AdvancedToolUse bundle unsupported), got: %s", path, string(result)) + } + } + }) + + t.Run("speed_stripped_on_non_opus_46_even_on_anthropic", func(t *testing.T) { + // Model gate: fast-mode is Opus 4.6 only per docs. 
Even on Anthropic + // direct where FastMode=true, targeting a different model must strip. + input := []byte(`{"model":"claude-sonnet-4-6","speed":"fast"}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-sonnet-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if providerUtils.JSONFieldExists(result, "speed") { + t.Errorf("expected speed stripped for non-Opus-4.6 model on Anthropic, got: %s", string(result)) + } + }) + + t.Run("anthropic_direct_is_noop", func(t *testing.T) { + // Anthropic supports everything — body should survive untouched. + input := []byte(`{"model":"claude-opus-4-6","speed":"fast","mcp_servers":[{"type":"url","url":"u","name":"n"}],"container":{"id":"c"},"tools":[{"name":"t","defer_loading":true,"input_examples":[{"input":{"a":1}}]}]}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, path := range []string{"speed", "mcp_servers", "container", "tools.0.defer_loading", "tools.0.input_examples"} { + if !providerUtils.JSONFieldExists(result, path) { + t.Errorf("expected %q preserved on Anthropic direct, got: %s", path, string(result)) + } + } + }) + + t.Run("nested_scope_stripped_on_messages_and_system", func(t *testing.T) { + // Nested scope on system blocks and message blocks must also be stripped + // when the provider lacks PromptCachingScope. 
+ input := []byte(`{ + "model":"claude-opus-4-6", + "system":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"user"}}], + "messages":[{"role":"user","content":[{"type":"text","text":"q","cache_control":{"type":"ephemeral","scope":"global"}}]}] + }`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, path := range []string{"system.0.cache_control.scope", "messages.0.content.0.cache_control.scope"} { + if providerUtils.JSONFieldExists(result, path) { + t.Errorf("expected nested %q stripped, got: %s", path, string(result)) + } + } + }) + + t.Run("unknown_provider_is_safe_noop", func(t *testing.T) { + input := []byte(`{"model":"claude-opus-4-6","speed":"fast"}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.ModelProvider("custom"), "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !providerUtils.JSONFieldExists(result, "speed") { + t.Errorf("expected speed preserved for unknown provider (safe default), got: %s", string(result)) + } + }) + + t.Run("container_empty_skills_stripped_but_container_preserved", func(t *testing.T) { + // Skills=false provider (Bedrock), ContainerBasic=true. + // skills:[] is a caller oversight — strip the empty key, preserve container. 
+ input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if providerUtils.JSONFieldExists(result, "container.skills") { + t.Errorf("expected empty container.skills stripped on Skills=false provider, got: %s", string(result)) + } + if !providerUtils.JSONFieldExists(result, "container.id") { + t.Errorf("expected container.id preserved (bare form still valid), got: %s", string(result)) + } + }) + + t.Run("container_nonempty_skills_drops_whole_container", func(t *testing.T) { + // Non-empty skills signals caller intent; provider doesn't support — drop container. + input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]}}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if providerUtils.JSONFieldExists(result, "container") { + t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got: %s", string(result)) + } + }) + + t.Run("container_empty_skills_on_skills_capable_provider_preserved", func(t *testing.T) { + // On Anthropic direct (Skills=true), the empty skills array must be preserved + // as-is — our strip logic only fires when !features.Skills. 
+ input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`) + result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-opus-4-6") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !providerUtils.JSONFieldExists(result, "container.skills") { + t.Errorf("expected container.skills preserved on Skills=true provider, got: %s", string(result)) + } + }) +} + +// TestStripUnsupportedAnthropicFields_ContainerSkillsGating mirrors the raw-path +// tests above on the typed path — ensures the typed sanitizer treats explicit +// empty skills arrays as a stripable (not drop-triggering) signal. +func TestStripUnsupportedAnthropicFields_ContainerSkillsGating(t *testing.T) { + t.Run("empty_skills_on_skills_false_provider_strips_skills_keeps_container", func(t *testing.T) { + req := &AnthropicMessageRequest{ + Model: "claude-opus-4-6", + Container: &AnthropicContainer{ + ContainerObject: &AnthropicContainerObject{ + ID: schemas.Ptr("c-1"), + Skills: []AnthropicContainerSkill{}, // explicit empty + }, + }, + } + stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6") + if req.Container == nil { + t.Fatalf("expected container preserved (bare form valid with empty skills), got nil") + } + if req.Container.ContainerObject == nil || req.Container.ContainerObject.Skills != nil { + t.Errorf("expected skills cleared on Skills=false, got %v", req.Container.ContainerObject) + } + }) + + t.Run("nonempty_skills_on_skills_false_provider_drops_container", func(t *testing.T) { + req := &AnthropicMessageRequest{ + Model: "claude-opus-4-6", + Container: &AnthropicContainer{ + ContainerObject: &AnthropicContainerObject{ + ID: schemas.Ptr("c-1"), + Skills: []AnthropicContainerSkill{{SkillID: "s", Type: "anthropic"}}, + }, + }, + } + stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6") + if req.Container != nil { + t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got 
%v", req.Container) + } + }) + + t.Run("empty_skills_on_skills_true_provider_preserved", func(t *testing.T) { + req := &AnthropicMessageRequest{ + Model: "claude-opus-4-6", + Container: &AnthropicContainer{ + ContainerObject: &AnthropicContainerObject{ + ID: schemas.Ptr("c-1"), + Skills: []AnthropicContainerSkill{}, + }, + }, + } + stripUnsupportedAnthropicFields(req, schemas.Anthropic, "claude-opus-4-6") + if req.Container == nil || req.Container.ContainerObject == nil { + t.Fatalf("expected container preserved on Skills=true provider, got %v", req.Container) + } + if req.Container.ContainerObject.Skills == nil { + t.Errorf("expected empty skills preserved on Skills=true provider (not nilled)") + } + }) +} + func TestStripAutoInjectableTools(t *testing.T) { t.Run("code_execution_without_web_search_preserved", func(t *testing.T) { // code_execution alone should NOT be stripped (no web_search/web_fetch to trigger auto-injection) @@ -1486,7 +1788,7 @@ func TestGetRequestBodyForResponses_RawBodyStripsFallbacks(t *testing.T) { RawRequestBody: rawBody, } - result, bifrostErr := getRequestBodyForResponses(ctx, request, schemas.Anthropic, false, nil) + result, bifrostErr := getRequestBodyForResponses(ctx, request, false, nil) if bifrostErr != nil { t.Fatalf("unexpected error: %v", bifrostErr) } @@ -1574,3 +1876,109 @@ func TestApplyMCPToolsetConfigToBifrostTool(t *testing.T) { applyMCPToolsetConfigToBifrostTool(&schemas.ResponsesTool{}, nil) }) } + +func TestSupportsAdaptiveThinking(t *testing.T) { + tests := []struct { + model string + expected bool + }{ + {"claude-opus-4-7-20260401", true}, + {"claude-opus-4.7-20260401", true}, + {"claude-opus-4-6-20250514", true}, + {"claude-opus-4.6-20250514", true}, + {"claude-sonnet-4-6-20250514", true}, + {"claude-sonnet-4.6-20250514", true}, + {"claude-opus-4-5-20241022", false}, + {"claude-sonnet-4-5-20241022", false}, + {"claude-haiku-4-6-20250514", false}, // haiku does not support adaptive + {"claude-haiku-4-7-20260401", 
false}, // haiku, not opus + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + got := SupportsAdaptiveThinking(tt.model) + if got != tt.expected { + t.Errorf("SupportsAdaptiveThinking(%q) = %v, want %v", tt.model, got, tt.expected) + } + }) + } +} + +func TestAddMissingBetaHeadersToContext_TaskBudgets(t *testing.T) { + tests := []struct { + name string + provider schemas.ModelProvider + req *AnthropicMessageRequest + expectHeaders []string + unexpectHeaders []string + }{ + { + name: "Anthropic gets task-budgets header when task_budget set", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + OutputConfig: &AnthropicOutputConfig{ + TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000}, + }, + }, + expectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, + }, + { + name: "Vertex does not get task-budgets header when task_budget set", + provider: schemas.Vertex, + req: &AnthropicMessageRequest{ + OutputConfig: &AnthropicOutputConfig{ + TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000}, + }, + }, + unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, + }, + { + name: "no task-budgets header when task_budget is nil", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{ + OutputConfig: &AnthropicOutputConfig{}, + }, + unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, + }, + { + name: "no task-budgets header when output_config is nil", + provider: schemas.Anthropic, + req: &AnthropicMessageRequest{}, + unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := schemas.NewBifrostContext(nil, time.Time{}) + AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider) + + var headers []string + if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { + headers = extraHeaders[AnthropicBetaHeader] + } + + for _, expected := range tt.expectHeaders { + 
found := false + for _, h := range headers { + if h == expected { + found = true + break + } + } + if !found { + t.Errorf("expected header %q not found in %v", expected, headers) + } + } + + for _, unexpected := range tt.unexpectHeaders { + for _, h := range headers { + if h == unexpected { + t.Errorf("unexpected header %q found in %v", unexpected, headers) + } + } + } + }) + } +} diff --git a/core/providers/anthropic/validate_chat_tools_test.go b/core/providers/anthropic/validate_chat_tools_test.go new file mode 100644 index 0000000000..d9f0c8a2df --- /dev/null +++ b/core/providers/anthropic/validate_chat_tools_test.go @@ -0,0 +1,138 @@ +package anthropic + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestValidateChatToolsForProvider locks in the partition: +// function/custom tools always survive; server tools survive only when the +// target provider's ProviderFeatures flag is true for that tool type. +func TestValidateChatToolsForProvider(t *testing.T) { + fnTool := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{Name: "get_weather"}, + } + serverTool := func(tpe, name string) schemas.ChatTool { + return schemas.ChatTool{Type: schemas.ChatToolType(tpe), Name: name} + } + + cases := []struct { + name string + provider schemas.ModelProvider + input []schemas.ChatTool + wantKeep int + wantDropped []string + assertNotes string + }{ + { + name: "function tools always survive on any provider", + provider: schemas.Bedrock, + input: []schemas.ChatTool{fnTool, fnTool}, + wantKeep: 2, + }, + { + name: "bedrock drops web_search", + provider: schemas.Bedrock, + input: []schemas.ChatTool{serverTool("web_search_20260209", "web_search")}, + wantKeep: 0, + wantDropped: []string{"web_search_20260209"}, + assertNotes: "Bedrock has WebSearch=false per Table 20 (AWS user guide beta-header list + Anthropic overview)", + }, + { + name: "bedrock drops web_fetch + code_execution + mcp_toolset", + provider: 
schemas.Bedrock, + input: []schemas.ChatTool{ + serverTool("web_fetch_20260309", "web_fetch"), + serverTool("code_execution_20250825", "code_execution"), + serverTool("mcp_toolset", "notion"), + }, + wantKeep: 0, + wantDropped: []string{"web_fetch_20260309", "code_execution_20250825", "mcp_toolset"}, + }, + { + name: "bedrock keeps computer/bash/memory/text_editor/tool_search", + provider: schemas.Bedrock, + input: []schemas.ChatTool{ + serverTool("computer_20251124", "computer"), + serverTool("bash_20250124", "bash"), + serverTool("memory_20250818", "memory"), + serverTool("text_editor_20250728", "str_replace_based_edit_tool"), + serverTool("tool_search_tool_bm25", "tool_search_tool_bm25"), + }, + wantKeep: 5, + }, + { + name: "bedrock partial drop mixes function + server tools", + provider: schemas.Bedrock, + input: []schemas.ChatTool{ + fnTool, + serverTool("web_search_20260209", "web_search"), + serverTool("bash_20250124", "bash"), + }, + wantKeep: 2, // fnTool + bash + wantDropped: []string{"web_search_20260209"}, + }, + { + name: "vertex drops web_fetch", + provider: schemas.Vertex, + input: []schemas.ChatTool{serverTool("web_fetch_20260309", "web_fetch")}, + wantKeep: 0, + wantDropped: []string{"web_fetch_20260309"}, + assertNotes: "Vertex has WebFetch=false per Table 20", + }, + { + name: "vertex drops mcp_toolset", + provider: schemas.Vertex, + input: []schemas.ChatTool{serverTool("mcp_toolset", "notion")}, + wantKeep: 0, + wantDropped: []string{"mcp_toolset"}, + assertNotes: "Vertex has MCP=false per MCP-excl (explicit exclusion in Anthropic docs)", + }, + { + name: "anthropic keeps everything", + provider: schemas.Anthropic, + input: []schemas.ChatTool{ + serverTool("web_search_20260209", "web_search"), + serverTool("web_fetch_20260309", "web_fetch"), + serverTool("code_execution_20250825", "code_execution"), + serverTool("mcp_toolset", "x"), + serverTool("computer_20251124", "computer"), + }, + wantKeep: 5, + }, + { + name: "unknown provider keeps 
everything (forward-compat)", + provider: schemas.ModelProvider("custom-new-provider"), + input: []schemas.ChatTool{serverTool("web_search_20260209", "web_search")}, + wantKeep: 1, + }, + { + name: "unknown tool type on known provider is kept (forward-compat)", + provider: schemas.Bedrock, + input: []schemas.ChatTool{serverTool("future_tool_20270101", "future")}, + wantKeep: 1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + keep, dropped := ValidateChatToolsForProvider(tc.input, tc.provider) + if len(keep) != tc.wantKeep { + t.Errorf("keep count: got %d, want %d (%s)", len(keep), tc.wantKeep, tc.assertNotes) + } + if len(dropped) != len(tc.wantDropped) { + t.Errorf("dropped count: got %v, want %v", dropped, tc.wantDropped) + } + for i, d := range tc.wantDropped { + if i >= len(dropped) { + break + } + if dropped[i] != d { + t.Errorf("dropped[%d]: got %q, want %q", i, dropped[i], d) + } + } + }) + } +} diff --git a/core/providers/azure/azure.go b/core/providers/azure/azure.go index 9649caa662..323d13584e 100644 --- a/core/providers/azure/azure.go +++ b/core/providers/azure/azure.go @@ -100,7 +100,7 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -109,11 +109,11 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx 
*schemas.BifrostContext, Scopes: scopes, }) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("Azure access token is empty", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -138,16 +138,16 @@ func (provider *AzureProvider) getAzureAuthHeaders(ctx *schemas.BifrostContext, cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty"), schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", errors.New("token is empty")) } authHeader["Authorization"] = fmt.Sprintf("Bearer %s", token.Token) @@ -206,10 +206,8 @@ func (provider *AzureProvider) completeRequest( jsonData []byte, path string, key schemas.Key, - deployment string, model string, - 
requestType schemas.RequestType, -) ([]byte, string, time.Duration, map[string]string, *schemas.BifrostError) { +) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) { // Create the request with the JSON body req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -222,7 +220,7 @@ func (provider *AzureProvider) completeRequest( }() var url string - isAnthropicModel := schemas.IsAnthropicModel(deployment) + isAnthropicModel := schemas.IsAnthropicModel(model) // Set any extra headers from network config. // For Anthropic models, exclude anthropic-beta — it is merged and filtered explicitly below. @@ -237,7 +235,7 @@ func (provider *AzureProvider) completeRequest( // Get authentication headers authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, isAnthropicModel) if bifrostErr != nil { - return nil, deployment, 0, nil, bifrostErr + return nil, 0, nil, bifrostErr } // Apply headers to request @@ -247,7 +245,7 @@ func (provider *AzureProvider) completeRequest( endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, deployment, 0, nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, 0, nil, providerUtils.NewConfigurationError("endpoint not set") } if isAnthropicModel { @@ -282,7 +280,7 @@ func (provider *AzureProvider) completeRequest( latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp) defer wait() if bifrostErr != nil { - return nil, deployment, latency, nil, bifrostErr + return nil, latency, nil, bifrostErr } // Extract provider response headers before body is copied — do this before status check @@ -292,33 +290,25 @@ func (provider *AzureProvider) completeRequest( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, deployment, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, requestType, 
provider.GetProviderKey(), model) + rawErrBody := append([]byte(nil), resp.Body()...) + return rawErrBody, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger) + body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { - return nil, deployment, latency, providerResponseHeaders, decodeErr + return nil, latency, providerResponseHeaders, decodeErr } if isLargeResp { respOwned = false - return nil, deployment, latency, providerResponseHeaders, nil + return nil, latency, providerResponseHeaders, nil } - return body, deployment, latency, providerResponseHeaders, nil + return body, latency, providerResponseHeaders, nil } // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - // Validate Azure key configuration - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", schemas.Azure) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", schemas.Azure) - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -359,12 +349,12 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -379,9 +369,9 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert to Bifrost response - response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, key.BlacklistedModels, request.Unfiltered) + response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure) + return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil) } response.ExtraFields.Latency = latency.Milliseconds() @@ -415,35 +405,23 @@ func (provider *AzureProvider) ListModels(ctx *schemas.BifrostContext, keys []sc // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized OpenAI text converter (Azure is OpenAI-compatible) jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAITextCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/completions", deployment), + fmt.Sprintf("openai/deployments/%s/completions", request.Model), key, - deployment, request.Model, - schemas.TextCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -457,10 +435,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostTextCompletionResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.TextCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -474,10 +448,6 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - 
response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -498,21 +468,12 @@ func (provider *AzureProvider) TextCompletion(ctx *schemas.BifrostContext, key s // It formats the request, sends it to Azure, and processes the response. // Returns a channel of BifrostStreamChunk objects or an error if the request fails. func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment := key.AzureKeyConfig.Deployments[request.Model] - if deployment == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey()) - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Get Azure authentication headers authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -520,11 +481,6 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, return nil, err } - customPostResponseConverter := func(response *schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse { - response.ExtraFields.ModelDeployment = 
deployment - return response - } - return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, @@ -538,7 +494,7 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, nil, postHookRunner, nil, - customPostResponseConverter, + nil, provider.logger, ) } @@ -547,26 +503,16 @@ func (provider *AzureProvider) TextCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { reqBody, err := anthropic.ToAnthropicChatRequest(ctx, request) if err != nil { return nil, err } if reqBody != nil { - reqBody.Model = deployment // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } @@ -574,27 +520,24 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } else { return openai.ToOpenAIChatRequest(ctx, request), nil } - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { - path = fmt.Sprintf("openai/deployments/%s/chat/completions", deployment) + path = fmt.Sprintf("openai/deployments/%s/chat/completions", request.Model) } - 
responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ChatCompletionRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -608,10 +551,6 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -622,7 +561,7 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -637,12 +576,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ChatCompletionRequest // Set raw request if enabled if 
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -662,22 +597,8 @@ func (provider *AzureProvider) ChatCompletion(ctx *schemas.BifrostContext, key s // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -694,14 +615,12 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, return nil, err } if reqBody != nil { - reqBody.Model = deployment reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Azure anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -719,13 +638,8 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - 
RequestType: schemas.ChatCompletionStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -736,7 +650,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) } - url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue()) + url = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue()) // Use shared streaming logic from OpenAI return openai.HandleOpenAIChatCompletionStreaming( @@ -754,7 +668,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, nil, nil, nil, - postResponseConverter, + nil, provider.logger, ) } @@ -764,51 +678,36 @@ func (provider *AzureProvider) ChatCompletionStream(ctx *schemas.BifrostContext, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - var jsonData []byte var bifrostErr *schemas.BifrostError - if schemas.IsAnthropicModel(deployment) { - jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), false) + if schemas.IsAnthropicModel(request.Model) { + jsonData, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false) } else { jsonData, bifrostErr = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { reqBody := openai.ToOpenAIResponsesRequest(request) - if reqBody != nil { - reqBody.Model = deployment - } return reqBody, nil - }, - provider.GetProviderKey()) + }) } if bifrostErr != nil { return nil, bifrostErr } var path string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { path = "anthropic/v1/messages" } else { path = "openai/v1/responses" } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, path, key, - deployment, request.Model, - schemas.ResponsesRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -822,10 +721,6 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - 
RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -836,7 +731,7 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema var rawRequest interface{} var rawResponse interface{} - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -851,12 +746,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema } } - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.RequestType = schemas.ResponsesRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -873,22 +764,8 @@ func (provider *AzureProvider) Responses(ctx *schemas.BifrostContext, key schema // ResponsesStream performs a streaming responses request to Azure's API. 
func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - - postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse { - response.ExtraFields.ModelDeployment = deployment - return response - } - var url string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { authHeader, err := provider.getAzureAuthHeaders(ctx, key, true) if err != nil { return nil, err @@ -896,7 +773,7 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post authHeader["anthropic-version"] = AzureAnthropicAPIVersionDefault url = fmt.Sprintf("%s/anthropic/v1/messages", key.AzureKeyConfig.Endpoint.GetValue()) - jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, provider.GetProviderKey(), true) + jsonData, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true) if bifrostErr != nil { return nil, bifrostErr } @@ -914,13 +791,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), provider.GetProviderKey(), postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }, ) } else { authHeader, err := provider.getAzureAuthHeaders(ctx, key, false) @@ -929,11 +801,6 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post } url = 
fmt.Sprintf("%s/openai/v1/responses?api-version=preview", key.AzureKeyConfig.Endpoint.GetValue()) - postRequestConverter := func(req *openai.OpenAIResponsesRequest) *openai.OpenAIResponsesRequest { - req.Model = deployment - return req - } - // Use shared streaming logic from OpenAI return openai.HandleOpenAIResponsesStreaming( ctx, @@ -948,8 +815,8 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post postHookRunner, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -959,35 +826,23 @@ func (provider *AzureProvider) ResponsesStream(ctx *schemas.BifrostContext, post // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - // Use centralized converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return openai.ToOpenAIEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, deployment, latency, providerResponseHeaders, err := provider.completeRequest( + responseBody, latency, providerResponseHeaders, err := provider.completeRequest( ctx, jsonData, - fmt.Sprintf("openai/deployments/%s/embeddings", deployment), + fmt.Sprintf("openai/deployments/%s/embeddings", request.Model), key, - deployment, request.Model, - schemas.EmbeddingRequest, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, 
providerResponseHeaders) @@ -1001,10 +856,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ModelDeployment: deployment, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1019,12 +870,8 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.Provider = provider.GetProviderKey() response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - response.ExtraFields.RequestType = schemas.EmbeddingRequest // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1041,15 +888,6 @@ func (provider *AzureProvider) Embedding(ctx *schemas.BifrostContext, key schema // Speech is not supported by the Azure provider. 
func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfig(key); err != nil { - return nil, err - } - - deployment, err := provider.getModelDeployment(key, request.Model) - if err != nil { - return nil, err - } - apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault) @@ -1057,10 +895,10 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K endpoint := key.AzureKeyConfig.Endpoint.GetValue() if endpoint == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("endpoint not set") } - url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, deployment, apiVersion.GetValue()) + url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", endpoint, request.Model, apiVersion.GetValue()) response, err := openai.HandleOpenAISpeechRequest( ctx, @@ -1080,9 +918,6 @@ func (provider *AzureProvider) Speech(ctx *schemas.BifrostContext, key schemas.K return nil, err } - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment - return response, err } @@ -1099,15 +934,6 @@ func (provider *AzureProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, // SpeechStream handles streaming for speech synthesis with Azure. // Azure sends raw binary audio bytes in SSE format, unlike OpenAI which sends JSON. 
 func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	// Get Azure authentication headers
 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1118,7 +944,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 	}
-	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())

 	// Create HTTP request for streaming
 	req := fasthttp.AcquireRequest()
@@ -1158,11 +984,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 			reqBody := openai.ToOpenAISpeechRequest(request)
 			if reqBody != nil {
 				reqBody.StreamFormat = schemas.Ptr("sse")
-				reqBody.Model = deployment // Replace model with deployment
 			}
 			return reqBody, nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1186,9 +1010,9 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(requestErr, fasthttp.ErrTimeout) || errors.Is(requestErr, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr, provider.GetProviderKey()), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, requestErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Extract provider response headers before status check so error responses also forward them
@@ -1197,7 +1021,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, openai.ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Create response channel
@@ -1207,11 +1031,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1312,11 +1137,6 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		var bifrostErr schemas.BifrostError
 		if errParseErr := sonic.Unmarshal(audioData, &bifrostErr); errParseErr == nil {
 			if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					RequestType:    schemas.SpeechStreamRequest,
-				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger)
 				return
@@ -1338,12 +1158,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		// Set extra fields for the response
 		response.ExtraFields = schemas.BifrostResponseExtraFields{
-			RequestType:     schemas.SpeechStreamRequest,
-			Provider:        provider.GetProviderKey(),
-			ModelRequested:  request.Model,
-			ModelDeployment: deployment,
-			ChunkIndex:      chunkIndex,
-			Latency:         time.Since(lastChunkTime).Milliseconds(),
+			ChunkIndex: chunkIndex,
+			Latency:    time.Since(lastChunkTime).Milliseconds(),
 		}
 		lastChunkTime = time.Now()
@@ -1372,7 +1188,7 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 		// a fake "done" response with truncated audio.
 		ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 		provider.logger.Warn("Error reading stream: %v", readErr)
-		providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, provider.GetProviderKey(), request.Model, provider.logger)
+		providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger)
 		return
 	}
 	break
@@ -1385,12 +1201,8 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 	finalResponse := schemas.BifrostSpeechStreamResponse{
 		Type: schemas.SpeechStreamResponseTypeDone,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:     schemas.SpeechStreamRequest,
-			Provider:        provider.GetProviderKey(),
-			ModelRequested:  request.Model,
-			ModelDeployment: deployment,
-			ChunkIndex:      chunkIndex + 1,
-			Latency:         time.Since(startTime).Milliseconds(),
+			ChunkIndex: chunkIndex + 1,
+			Latency:    time.Since(startTime).Milliseconds(),
 		},
 	}

@@ -1413,21 +1225,12 @@ func (provider *AzureProvider) SpeechStream(ctx *schemas.BifrostContext, postHoo
 // Transcription performs an audio transcription request to Azure's API.
 func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, err := provider.getModelDeployment(key, request.Model)
-	if err != nil {
-		return nil, err
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
 	}
-	url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", key.AzureKeyConfig.Endpoint.GetValue(), request.Model, apiVersion.GetValue())

 	response, err := openai.HandleOpenAITranscriptionRequest(
 		ctx,
@@ -1446,9 +1249,6 @@ func (provider *AzureProvider) Transcription(ctx *schemas.BifrostContext, key sc
 		return nil, err
 	}

-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }

@@ -1462,16 +1262,6 @@ func (provider *AzureProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 // Returns a BifrostResponse containing the bifrost response or an error if the request fails.
 func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
@@ -1479,13 +1269,13 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	response, err := openai.HandleOpenAIImageGenerationRequest(
 		ctx,
 		provider.client,
-		fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue()),
+		fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue()),
 		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
@@ -1498,9 +1288,6 @@ func (provider *AzureProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		return nil, err
 	}

-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }

@@ -1513,18 +1300,6 @@ func (provider *AzureProvider) ImageGenerationStream(
 	key schemas.Key,
 	request *schemas.BifrostImageGenerationRequest,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	//
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionDefault)
@@ -1532,17 +1307,10 @@ func (provider *AzureProvider) ImageGenerationStream(
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

-	postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse {
-		if resp != nil {
-			resp.ExtraFields.ModelDeployment = deployment
-		}
-		return resp
-	}
-
-	url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", endpoint, request.Model, apiVersion.GetValue())

 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1563,7 +1331,7 @@ func (provider *AzureProvider) ImageGenerationStream(
 		postHookRunner,
 		nil,
 		nil,
-		postResponseConverter,
+		nil,
 		provider.logger,
 	)

@@ -1571,16 +1339,6 @@ func (provider *AzureProvider) ImageGenerationStream(
 // ImageEdit performs an image edit request to Azure's API.
 func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
@@ -1588,10 +1346,10 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

-	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue())
 	response, err := openai.HandleOpenAIImageEditRequest(
 		ctx,
 		provider.client,
@@ -1608,24 +1366,11 @@ func (provider *AzureProvider) ImageEdit(ctx *schemas.BifrostContext, key schema
 		return nil, err
 	}

-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, err
 }

 // ImageEditStream performs a streaming image edit request to Azure's API.
 func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	// Validate api key configs
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment := key.AzureKeyConfig.Deployments[request.Model]
-	if deployment == "" {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), provider.GetProviderKey())
-	}
-
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil || apiVersion.GetValue() == "" {
 		apiVersion = schemas.NewEnvVar(AzureAPIVersionImageEditDefault)
@@ -1633,17 +1378,10 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

-	postResponseConverter := func(resp *schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse {
-		if resp != nil {
-			resp.ExtraFields.ModelDeployment = deployment
-		}
-		return resp
-	}
-
-	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, deployment, apiVersion.GetValue())
+	url := fmt.Sprintf("%s/openai/deployments/%s/images/edits?api-version=%s", endpoint, request.Model, apiVersion.GetValue())

 	authHeader, err := provider.getAzureAuthHeaders(ctx, key, false)
 	if err != nil {
@@ -1664,7 +1402,7 @@ func (provider *AzureProvider) ImageEditStream(ctx *schemas.BifrostContext, post
 		postHookRunner,
 		nil,
 		nil,
-		postResponseConverter,
+		nil,
 		provider.logger,
 	)

@@ -1678,30 +1416,19 @@ func (provider *AzureProvider) ImageVariation(ctx *schemas.BifrostContext, key s
 // VideoGeneration creates a video using Azure's OpenAI-compatible Sora API.
 // This delegates to the OpenAI handler with Azure-specific URL and authentication.
 func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
-	deployment, bifrostErr := provider.getModelDeployment(key, request.Model)
-	if bifrostErr != nil {
-		return nil, bifrostErr
-	}
-
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	// Build Azure URL for OpenAI-compatible video generation endpoint
 	url := fmt.Sprintf("%s/openai/v1/videos", endpoint)

-	requestCopy := *request
-	requestCopy.Model = deployment
 	response, bifrostErr := openai.HandleOpenAIVideoGenerationRequest(
 		ctx,
 		provider.client,
 		url,
-		&requestCopy,
+		request,
 		key,
 		provider.networkConfig.ExtraHeaders,
 		provider.GetProviderKey(),
@@ -1713,27 +1440,20 @@ func (provider *AzureProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		return nil, bifrostErr
 	}

-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.ModelDeployment = deployment
-
 	return response, nil
 }

 // VideoRetrieve retrieves the status of a video from Azure's OpenAI-compatible API.
 func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)

 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", providerName)
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	authHeaders, bifrostErr := provider.getAzureAuthHeaders(ctx, key, false)
@@ -1759,20 +1479,16 @@ func (provider *AzureProvider) VideoRetrieve(ctx *schemas.BifrostContext, key sc
 // VideoDownload downloads video content from Azure's OpenAI-compatible API.
 func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)

 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", providerName)
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	// Create request
@@ -1808,13 +1524,12 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, openai.ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "")
+		return nil, openai.ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Get content type from response
@@ -1830,9 +1545,7 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc
 		Content:     append([]byte(nil), body...),
 		ContentType: contentType,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.VideoDownloadRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}

@@ -1841,20 +1554,16 @@ func (provider *AzureProvider) VideoDownload(ctx *schemas.BifrostContext, key sc
 // VideoDelete deletes a video from Azure's OpenAI-compatible API.
 func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)

 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", providerName)
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	// Build Azure URL
@@ -1881,13 +1590,9 @@ func (provider *AzureProvider) VideoDelete(ctx *schemas.BifrostContext, key sche
 // VideoList lists videos from Azure's OpenAI-compatible API.
 func (provider *AzureProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfig(key); err != nil {
-		return nil, err
-	}
-
 	endpoint := key.AzureKeyConfig.Endpoint.GetValue()
 	if endpoint == "" {
-		return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
+		return nil, providerUtils.NewConfigurationError("endpoint not set")
 	}

 	// Build Azure URL
@@ -1917,64 +1622,14 @@ func (provider *AzureProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.K
 	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
 }

-// validateKeyConfig validates the key configuration.
-// It checks if the key config is set, the endpoint is set, and the deployments are set.
-// Returns an error if any of the checks fail.
-func (provider *AzureProvider) validateKeyConfig(key schemas.Key) *schemas.BifrostError {
-	if key.AzureKeyConfig == nil {
-		return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
-	}
-
-	if key.AzureKeyConfig.Endpoint.GetValue() == "" {
-		return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
-	}
-
-	if key.AzureKeyConfig.Deployments == nil {
-		return providerUtils.NewConfigurationError("deployments not set", provider.GetProviderKey())
-	}
-
-	return nil
-}
-
-// validateKeyConfigForFiles validates key config for file/batch APIs, which only
-// require a configured Azure endpoint (no per-model deployments needed).
-func (provider *AzureProvider) validateKeyConfigForFiles(key schemas.Key) *schemas.BifrostError {
-	if key.AzureKeyConfig == nil {
-		return providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
-	}
-	if key.AzureKeyConfig.Endpoint.GetValue() == "" {
-		return providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey())
-	}
-	return nil
-}
-
-func (provider *AzureProvider) getModelDeployment(key schemas.Key, model string) (string, *schemas.BifrostError) {
-	if key.AzureKeyConfig == nil {
-		return "", providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey())
-	}
-
-	if key.AzureKeyConfig.Deployments != nil {
-		if deployment, ok := key.AzureKeyConfig.Deployments[model]; ok {
-			return deployment, nil
-		}
-	}
-	return "", providerUtils.NewConfigurationError(fmt.Sprintf("deployment not found for model %s", model), provider.GetProviderKey())
-}
-
 // FileUpload uploads a file to Azure OpenAI.
 func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
-	if err := provider.validateKeyConfigForFiles(key); err != nil {
-		return nil, err
-	}
-
-	providerName := provider.GetProviderKey()
-
 	if len(request.File) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
 	}

 	if request.Purpose == "" {
-		return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("purpose is required", nil)
 	}

 	// Get API version
@@ -1989,7 +1644,7 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem
 	// Add purpose field
 	if err := writer.WriteField("purpose", string(request.Purpose)); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err)
 	}

 	// Add file field
@@ -1999,14 +1654,14 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem
 	}
 	part, err := writer.CreateFormFile("file", filename)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
 	}
 	if _, err := part.Write(request.File); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
 	}

 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
 	}

 	// Create request
@@ -2043,13 +1698,12 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, openai.ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "")
+		return nil, openai.ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	var openAIResp openai.OpenAIFileResponse
@@ -2060,17 +1714,15 @@ func (provider *AzureProvider) FileUpload(ctx *schemas.BifrostContext, key schem
 		return nil, bifrostErr
 	}

-	return openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
+	return openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
 }

 // FileList lists files using serial pagination across keys.
 // Exhausts all pages from one key before moving to the next.
 func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
-		return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation", providerName)
+		return nil, providerUtils.NewConfigurationError("no Azure keys available for file list operation")
 	}

 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -2079,7 +1731,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche
 	// Initialize serial pagination helper
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}

 	// Get current key to query
@@ -2090,18 +1742,9 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche
 			Object:  "list",
 			Data:    []schemas.FileObject{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}

-	// Validate key config
-	if err := provider.validateKeyConfigForFiles(key); err != nil {
-		return nil, err
-	}
-
 	// Get API version
 	apiVersion := key.AzureKeyConfig.APIVersion
 	if apiVersion == nil {
@@ -2149,13 +1792,12 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, openai.ParseOpenAIError(resp, schemas.FileListRequest, providerName, "")
+		return nil, openai.ParseOpenAIError(resp)
 	}

 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}

 	var openAIResp openai.OpenAIFileListResponse
@@ -2190,9 +1832,7 @@ func (provider *AzureProvider) FileList(ctx *schemas.BifrostContext, keys []sche
 		Data:    files,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.FileListRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 	if nextCursor != "" {
@@ -2207,7 +1847,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []
 	providerName := provider.GetProviderKey()

 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}

 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -2215,11 +1855,6 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []
 	var lastErr *schemas.BifrostError
 	for _, key := range keys {
-		if err := provider.validateKeyConfigForFiles(key); err != nil {
-			lastErr = err
-			continue
-		}
-
 		// Get API version
 		apiVersion := key.AzureKeyConfig.APIVersion
 		if apiVersion == nil {
@@ -2262,8 +1897,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = openai.ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "")
+			lastErr = openai.ParseOpenAIError(resp)
 			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
@@ -2275,7 +1909,7 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []
 			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}

@@ -2301,14 +1935,12 @@ func (provider *AzureProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []
 // FileDelete deletes a file from Azure OpenAI by trying each key until successful.
 func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}

 	if len(keys) == 0 {
-		return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation", providerName)
+		return nil, providerUtils.NewConfigurationError("no Azure keys available for file delete operation")
 	}

 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
@@ -2316,11 +1948,6 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 	var lastErr *schemas.BifrostError
 	for _, key := range keys {
-		if err := provider.validateKeyConfigForFiles(key); err != nil {
-			lastErr = err
-			continue
-		}
-
 		// Get API version
 		apiVersion := key.AzureKeyConfig.APIVersion
 		if apiVersion == nil {
@@ -2363,8 +1990,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent {
-			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = openai.ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "")
+			lastErr = openai.ParseOpenAIError(resp)
 			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
@@ -2380,9 +2006,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 				Object:  "file",
 				Deleted: true,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.FileDeleteRequest,
-					Provider:    providerName,
-					Latency:     latency.Milliseconds(),
+					Latency: latency.Milliseconds(),
 				},
 			}, nil
 		}
@@ -2392,7 +2016,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}

@@ -2415,9 +2039,7 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 			Object:  openAIResp.Object,
 			Deleted: openAIResp.Deleted,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.FileDeleteRequest,
-				Provider:    providerName,
-				Latency:     latency.Milliseconds(),
+				Latency: latency.Milliseconds(),
 			},
 		}

@@ -2437,24 +2059,17 @@ func (provider *AzureProvider) FileDelete(ctx *schemas.BifrostContext, keys []sc
 // FileContent downloads file content from Azure OpenAI by trying each key until found.
 func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
 	}

 	if len(keys) == 0 {
-		return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation", providerName)
+		return nil, providerUtils.NewConfigurationError("no Azure keys available for file content operation")
 	}

 	var lastErr *schemas.BifrostError

 	for _, key := range keys {
-		if err := provider.validateKeyConfigForFiles(key); err != nil {
-			lastErr = err
-			continue
-		}
-
 		// Get API version
 		apiVersion := key.AzureKeyConfig.APIVersion
 		if apiVersion == nil {
@@ -2496,8 +2111,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-			lastErr = openai.ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "")
+			lastErr = openai.ParseOpenAIError(resp)
			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
@@ -2509,7 +2123,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s
 			wait()
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}

@@ -2529,9 +2143,7 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s
 			Content:     content,
 			ContentType: contentType,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:
schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2542,12 +2154,6 @@ func (provider *AzureProvider) FileContent(ctx *schemas.BifrostContext, keys []s // BatchCreate creates a new batch job on Azure OpenAI. // Azure Batch API uses the same format as OpenAI but with Azure-specific URL patterns. func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - - providerName := provider.GetProviderKey() - inputFileID := request.InputFileID // If no file_id provided but inline requests are available, upload them first @@ -2555,12 +2161,11 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Convert inline requests to JSONL format jsonlData, err := openai.ConvertRequestsToJSONL(request.Requests) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Upload the file with purpose "batch" uploadResp, bifrostErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{ - Provider: schemas.Azure, File: jsonlData, Filename: "batch_requests.jsonl", Purpose: "batch", @@ -2574,7 +2179,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Validate that we have a file ID (either provided or uploaded) if inputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for Azure batch API", nil) } // Get API version @@ -2621,7 
+2226,7 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche jsonData, err := providerUtils.MarshalSorted(openAIReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } req.SetBody(jsonData) @@ -2634,13 +2239,12 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } var openAIResp openai.OpenAIBatchResponse @@ -2651,25 +2255,24 @@ func (provider *AzureProvider) BatchCreate(ctx *schemas.BifrostContext, key sche return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil + return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil } // BatchList lists batch jobs from all provided Azure keys and aggregates results. 
// BatchList lists batch jobs using serial pagination across keys. // Exhausts all pages from one key before moving to the next. func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch list operation") } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2680,18 +2283,9 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - // Validate key config - if err := provider.validateKeyConfigForFiles(key); err != nil { - return nil, err - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2737,13 +2331,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, openai.ParseOpenAIError(resp, 
schemas.BatchListRequest, providerName, "") + return nil, openai.ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp openai.OpenAIBatchListResponse @@ -2756,7 +2349,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data)) var lastBatchID string for _, batch := range openAIResp.Data { - batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) + batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)) lastBatchID = batch.ID } @@ -2769,9 +2362,7 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2783,14 +2374,12 @@ func (provider *AzureProvider) BatchList(ctx *schemas.BifrostContext, keys []sch // BatchRetrieve retrieves a specific batch job from Azure OpenAI by trying each key until found. 
func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch retrieve operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2798,11 +2387,6 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2845,8 +2429,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2858,7 +2441,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2876,8 
+2459,7 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) - result.ExtraFields.RequestType = schemas.BatchRetrieveRequest + result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) return result, nil } @@ -2886,14 +2468,12 @@ func (provider *AzureProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys [ // BatchCancel cancels a batch job on Azure OpenAI by trying each key until successful. func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation", providerName) + return nil, providerUtils.NewConfigurationError("no Azure keys available for batch cancel operation") } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -2901,11 +2481,6 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s var lastErr *schemas.BifrostError for _, key := range keys { - if err := provider.validateKeyConfigForFiles(key); err != nil { - lastErr = err - continue - } - // Get API version apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { @@ -2948,8 +2523,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != 
fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = openai.ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "") + lastErr = openai.ParseOpenAIError(resp) wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) @@ -2961,7 +2535,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s wait() fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -2986,9 +2560,7 @@ func (provider *AzureProvider) BatchCancel(ctx *schemas.BifrostContext, keys []s CancellingAt: openAIResp.CancellingAt, CancelledAt: openAIResp.CancelledAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3022,8 +2594,6 @@ func (provider *AzureProvider) BatchDelete(ctx *schemas.BifrostContext, keys []s // BatchResults retrieves batch results from Azure OpenAI by trying each key until successful. // For Azure (like OpenAI), batch results are obtained by downloading the output_file_id. 
func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output_file_id (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3034,7 +2604,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil) } // Download the output file content (using all keys) @@ -3063,9 +2633,7 @@ func (provider *AzureProvider) BatchResults(ctx *schemas.BifrostContext, keys [] BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } @@ -3132,14 +2700,6 @@ func (provider *AzureProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -3176,7 +2736,7 @@ func (provider *AzureProvider) Passthrough( 
body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } // Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body. @@ -3192,9 +2752,6 @@ func (provider *AzureProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3212,14 +2769,6 @@ func (provider *AzureProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.AzureKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("azure key config not set", provider.GetProviderKey()) - } - - if key.AzureKeyConfig.Endpoint.GetValue() == "" { - return nil, providerUtils.NewConfigurationError("endpoint not set", provider.GetProviderKey()) - } - url := provider.buildPassthroughURL(key, req.Path, req.RawQuery) fasthttpReq := fasthttp.AcquireRequest() @@ -3266,9 +2815,9 @@ func (provider *AzureProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := 
providerUtils.ExtractProviderResponseHeaders(resp) @@ -3276,21 +2825,13 @@ func (provider *AzureProvider) PassthroughStream( rawBodyStream := resp.BodyStream() if rawBodyStream == nil { providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError( - "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + return nil, providerUtils.NewBifrostOperationError("provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body")) } bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx)) stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields) } @@ -3298,11 +2839,12 @@ func (provider *AzureProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -3351,7 +2893,7 
@@ func (provider *AzureProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/azure/azure_test.go b/core/providers/azure/azure_test.go index b1080b7db3..32ec72040b 100644 --- a/core/providers/azure/azure_test.go +++ b/core/providers/azure/azure_test.go @@ -26,12 +26,12 @@ func TestAzure(t *testing.T) { testConfig := llmtests.ComprehensiveTestConfig{ Provider: schemas.Azure, - ChatModel: "gpt-4o-backup", - PromptCachingModel: "gpt-4o-backup", + ChatModel: "gpt-4o", + PromptCachingModel: "gpt-4o", VisionModel: "gpt-4o", ChatAudioModel: "gpt-4o-mini-audio-preview", Fallbacks: []schemas.Fallback{ - {Provider: schemas.Azure, Model: "gpt-4o-backup"}, + {Provider: schemas.Azure, Model: "gpt-4o"}, }, TextModel: "", // Azure doesn't support text completion in newer models EmbeddingModel: "text-embedding-ada-002", @@ -60,7 +60,7 @@ func TestAzure(t *testing.T) { Embedding: true, ListModels: true, Reasoning: true, - ChatAudio: true, + ChatAudio: false, Transcription: false, // Disabled for azure because of 3 calls/minute quota TranscriptionStream: false, // Not properly supported yet by Azure SpeechSynthesis: false, // Disabled for azure because of 3 calls/minute quota @@ -78,8 +78,10 @@ func TestAzure(t *testing.T) { VideoRemix: false, VideoList: false, VideoDelete: false, - InterleavedThinking: true, - PassthroughAPI: true, + InterleavedThinking: true, + PassthroughAPI: true, + EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (Beta on Azure Foundry) + ServerToolsViaOpenAIEndpoint: true, // web_search / web_fetch / code_execution on Azure per Table 20 }, DisableParallelFor: 
[]string{"Transcription"}, // Azure Whisper has 3 calls/minute quota } diff --git a/core/providers/azure/files.go b/core/providers/azure/files.go index 4c7ce174f8..d008b146de 100644 --- a/core/providers/azure/files.go +++ b/core/providers/azure/files.go @@ -24,7 +24,7 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" { cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue()) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err) } scopes := getAzureScopes(key.AzureKeyConfig.Scopes) @@ -33,11 +33,11 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R Scopes: scopes, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to get Azure access token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("failed to get Azure access token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("Azure access token is empty", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -68,16 +68,16 @@ func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.R cred, err := provider.getOrCreateDefaultAzureCredential() if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err, schemas.Azure) + 
return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err) } token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes}) if err != nil { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err, schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err) } if token.Token == "" { - return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"), schemas.Azure) + return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty")) } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token)) @@ -110,9 +110,7 @@ func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mod StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/azure/models.go b/core/providers/azure/models.go index d5ff81229a..5daca3836d 100644 --- a/core/providers/azure/models.go +++ b/core/providers/azure/models.go @@ -1,65 +1,13 @@ package azure import ( - "slices" + "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). 
-func findMatchingAllowedModel(slice []string, value string) string { - // First check exact match - if slices.Contains(slice, value) { - return value - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Return the item from slice (not value) to use the actual name from allowedModels - for _, item := range slice { - if schemas.SameBaseModel(item, value) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and base model matches (ignoring version suffixes). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check exact match first (by alias/key) - if deployment, ok := deployments[modelID]; ok { - return deployment, modelID - } - - // Check exact match by deployment value - for aliasKey, depValue := range deployments { - if depValue == modelID { - return depValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - for aliasKey, deploymentValue := range deployments { - // Check if modelID's base matches deploymentValue's base - if schemas.SameBaseModel(deploymentValue, modelID) { - return deploymentValue, aliasKey - } - // Also check if modelID's base matches alias's base (for cases where alias is used as deployment) - if schemas.SameBaseModel(aliasKey, modelID) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *AzureListModelsResponse) 
ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -68,111 +16,36 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
 		Data: make([]schemas.Model, 0, len(response.Data)),
 	}
 
-	includedModels := make(map[string]bool)
-	for _, model := range response.Data {
-		modelID := model.ID
-		matchedAllowedModel := ""
-		deploymentValue := ""
-		deploymentAlias := ""
-
-		// Filter if model is not present in both lists (when both are non-empty)
-		// Empty lists mean "allow all" for that dimension
-		// Check considering base model matches (ignoring version suffixes)
-		shouldFilter := false
-		if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 {
-			// Both lists are present: model must be in allowedModels AND deployments
-			// AND the deployment alias must also be in allowedModels
-			matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID)
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID)
-			inDeployments := deploymentAlias != ""
-
-			// Check if deployment alias is also in allowedModels (direct string match)
-			deploymentAliasInAllowedModels := false
-			if deploymentAlias != "" {
-				deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias)
-			}
-
-			// Filter if: model not in deployments OR deployment alias not in allowedModels
-			shouldFilter = !inDeployments || !deploymentAliasInAllowedModels
-		} else if !unfiltered && len(allowedModels) > 0 {
-			// Only allowedModels is present: filter if model is not in allowedModels
-			matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ID)
-			shouldFilter = matchedAllowedModel == ""
-		} else if !unfiltered && len(deployments) > 0 {
-			// Only deployments is present: filter if model is not in deployments
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ID)
-			shouldFilter = deploymentValue == ""
-		}
-		// If both are empty, shouldFilter remains false (allow all)
-
-		if shouldFilter {
-			continue
-		}
-
-		// Use the matched name from allowedModels or deployments (like Anthropic)
-		// Priority: deployment value > matched allowedModel > original model.ID
-		if deploymentValue != "" {
-			modelID = deploymentValue
-		} else if matchedAllowedModel != "" {
-			modelID = matchedAllowedModel
-		}
-
-		if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ID, modelID, deploymentAlias, matchedAllowedModel) {
-			continue
-		}
-
-		modelEntry := schemas.Model{
-			ID:      string(schemas.Azure) + "/" + modelID,
-			Created: schemas.Ptr(model.CreatedAt),
-		}
-		// Set deployment info if matched via deployments
-		if deploymentValue != "" && deploymentAlias != "" {
-			modelEntry.ID = string(schemas.Azure) + "/" + deploymentAlias
-			modelEntry.Deployment = schemas.Ptr(deploymentValue)
-			includedModels[deploymentAlias] = true
-		} else {
-			includedModels[modelID] = true
-		}
-
-		bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       schemas.Azure,
+		MatchFns:          providerUtils.DefaultMatchFns(),
 	}
-
-	// Backfill deployments that were not matched from the API response
-	if !unfiltered && len(deployments) > 0 {
-		for alias, deploymentValue := range deployments {
-			if includedModels[alias] {
-				continue
-			}
-			// If allowedModels is non-empty, only include if alias is in the list
-			if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
-				continue
-			}
-			if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) {
-				continue
-			}
-			bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-				ID:         string(schemas.Azure) + "/" + alias,
-				Name:       schemas.Ptr(alias),
-				Deployment: schemas.Ptr(deploymentValue),
-			})
-			includedModels[alias] = true
-		}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
 	}
-	// Backfill allowed models that were not in the response
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) {
-				continue
+	included := make(map[string]bool)
+
+	for _, model := range response.Data {
+		for _, result := range pipeline.FilterModel(model.ID) {
+			entry := schemas.Model{
+				ID:      string(schemas.Azure) + "/" + result.ResolvedID,
+				Created: schemas.Ptr(model.CreatedAt),
 			}
-			if !includedModels[allowedModel] {
-				bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-					ID:   string(schemas.Azure) + "/" + allowedModel,
-					Name: schemas.Ptr(allowedModel),
-				})
+			if result.AliasValue != "" {
+				entry.Alias = schemas.Ptr(result.AliasValue)
 			}
+			bifrostResponse.Data = append(bifrostResponse.Data, entry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
 	}
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
+
 	return bifrostResponse
 }
diff --git a/core/providers/azure/utils.go b/core/providers/azure/utils.go
index 49d1db8de3..20f216c19e 100644
--- a/core/providers/azure/utils.go
+++ b/core/providers/azure/utils.go
@@ -9,7 +9,7 @@ import (
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
-func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool) ([]byte, *schemas.BifrostError) {
+func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) {
 	// Large payload mode: body streams directly from the LP reader — skip all body building
 	// (matches CheckContextAndGetRequestBody guard).
 	if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
@@ -27,24 +27,24 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s
 		if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
 			jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 		// Replace model with deployment
 		jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Delete fallbacks field
 		jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Add stream if streaming
 		if isStreaming {
 			jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 	} else {
@@ -52,10 +52,10 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s
 		request.Model = deployment
 		reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
 		if convErr != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
 		}
 		if reqBody == nil {
-			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
 		}
 
 		if isStreaming {
@@ -68,7 +68,7 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s
 		// Marshal struct to JSON bytes, preserving field order
 		jsonBody, err = providerUtils.MarshalSorted(reqBody)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err), providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err))
 		}
 	}
diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go
index 3cddc86340..6b7cf700f0 100644
--- a/core/providers/bedrock/bedrock.go
+++ b/core/providers/bedrock/bedrock.go
@@ -26,7 +26,6 @@ import (
 	"github.com/bytedance/sonic"
 	"github.com/google/uuid"
 	"github.com/maximhq/bifrost/core/providers/anthropic"
-	"github.com/maximhq/bifrost/core/providers/cohere"
 	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	schemas "github.com/maximhq/bifrost/core/schemas"
 )
@@ -222,7 +221,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js
 		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue()))
 	} else {
 		// Sign the request using either explicit credentials or IAM role authentication
-		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil {
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
 			return nil, 0, nil, err
 		}
 	}
@@ -245,10 +244,10 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js
 		// Check for timeout first using net.Error before checking net.OpError
 		var netErr net.Error
 		if errors.As(err, &netErr) && netErr.Timeout() {
-			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		// Check for DNS lookup and network errors after timeout checks
 		var opErr *net.OpError
@@ -349,7 +348,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 	if key.Value.GetValue() != "" {
 		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue()))
 	} else {
-		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil {
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
 			return nil, 0, nil, err
 		}
 	}
@@ -370,10 +369,10 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 		}
 		var netErr net.Error
 		if errors.As(err, &netErr) && netErr.Timeout() {
-			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey())
+			return nil, latency, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		var opErr *net.OpError
 		var dnsErr *net.DNSError
@@ -420,15 +419,9 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
 // makeStreamingRequest creates a streaming request to Bedrock's API.
 // It formats the request, sends it to Bedrock, and returns the response.
 // Returns the response body and an error if the request fails.
-func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, string, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, "", providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
+func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContext, jsonData []byte, key schemas.Key, model string, action string) (*http.Response, *schemas.BifrostError) {
 	// Format the path with proper model identifier for streaming
-	path, deployment := provider.getModelPath(action, model, key)
+	path := provider.getModelPath(action, model, key)
 
 	region := DefaultBedrockRegion
 	if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" {
@@ -438,7 +431,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	// Create HTTP request for streaming
 	req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s", region, path), bytes.NewReader(jsonData))
 	if reqErr != nil {
-		return nil, deployment, providerUtils.NewBifrostOperationError("error creating request", reqErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError("error creating request", reqErr)
 	}
 
 	// Set any extra headers from network config
@@ -457,8 +450,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	} else {
 		req.Header.Set("Accept", "application/vnd.amazon.eventstream")
 		// Sign the request using either explicit credentials or IAM role authentication
-		if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
-			return nil, deployment, err
+		if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
+			return nil, err
 		}
 	}
@@ -466,7 +459,7 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	resp, respErr := provider.client.Do(req)
 	if respErr != nil {
 		if errors.Is(respErr, context.Canceled) {
-			return nil, deployment, &schemas.BifrostError{
+			return nil, &schemas.BifrostError{
 				IsBifrostError: false,
 				Error: &schemas.ErrorField{
 					Type: schemas.Ptr(schemas.RequestCancelled),
@@ -478,35 +471,29 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 		// Check for timeout first using net.Error before checking net.OpError
 		var netErr net.Error
 		if errors.As(respErr, &netErr) && netErr.Timeout() {
-			return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr)
 		}
 		if errors.Is(respErr, http.ErrHandlerTimeout) || errors.Is(respErr, context.DeadlineExceeded) {
-			return nil, deployment, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, respErr)
 		}
 		// Check for DNS lookup and network errors after timeout checks
 		var opErr *net.OpError
 		var dnsErr *net.DNSError
 		if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) {
-			return nil, deployment, &schemas.BifrostError{
+			return nil, &schemas.BifrostError{
 				IsBifrostError: false,
 				Error: &schemas.ErrorField{
 					Message: schemas.ErrProviderNetworkError,
 					Error:   respErr,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					Provider: providerName,
-				},
 			}
 		}
-		return nil, deployment, &schemas.BifrostError{
+		return nil, &schemas.BifrostError{
 			IsBifrostError: false,
 			Error: &schemas.ErrorField{
 				Message: schemas.ErrProviderDoRequest,
 				Error:   respErr,
 			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				Provider: providerName,
-			},
 		}
 	}
@@ -517,10 +504,10 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
 		resp.Body.Close()
-		return nil, deployment, parseBedrockHTTPError(resp.StatusCode, resp.Header, body)
+		return nil, parseBedrockHTTPError(resp.StatusCode, resp.Header, body)
 	}
 
-	return resp, deployment, nil
+	return resp, nil
 }
 
 // signAWSRequest signs an HTTP request using AWS Signature Version 4.
@@ -537,7 +524,6 @@ func signAWSRequest(
 	externalID *schemas.EnvVar,
 	sessionName *schemas.EnvVar,
 	region, service string,
-	providerName schemas.ModelProvider,
 ) *schemas.BifrostError {
 	// Set required headers before signing (only if not already set)
 	if req.Header.Get("Content-Type") == "" {
@@ -552,7 +538,7 @@ func signAWSRequest(
 	if req.Body != nil {
 		bodyBytes, err := io.ReadAll(req.Body)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("error reading request body", err, providerName)
+			return providerUtils.NewBifrostOperationError("error reading request body", err)
 		}
 		// Restore the body for subsequent reads
 		req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -594,11 +580,10 @@ func signAWSRequest(
 		)
 	}
 	if err != nil {
-		return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to load aws config", err)
 	}
 
 	if roleARN != nil && roleARN.GetValue() != "" {
-
 		extID := ""
 		if externalID != nil {
 			extID = externalID.GetValue()
@@ -653,12 +638,12 @@ func signAWSRequest(
 	// Get credentials
 	creds, err := cfg.Credentials.Retrieve(ctx)
 	if err != nil {
-		return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
 	}
 
 	// Sign the request with AWS Signature V4
 	if err := signer.SignHTTP(ctx, creds, req, bodyHash, service, region, time.Now()); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to sign request", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to sign request", err)
 	}
 
 	return nil
@@ -668,13 +653,7 @@ func signAWSRequest(
 // It retrieves all foundation models available in Amazon Bedrock for a specific key.
 func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
 	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	config := key.BedrockKeyConfig
-
 	region := DefaultBedrockRegion
 	if config.Region != nil && config.Region.GetValue() != "" {
 		region = config.Region.GetValue()
@@ -721,7 +700,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 	} else {
 		// Sign the request using either explicit credentials or IAM role authentication
-		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
+		if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
 			return nil, err
 		}
 	}
@@ -744,10 +723,10 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 		// Check for timeout first using net.Error before checking net.OpError
 		var netErr net.Error
 		if errors.As(err, &netErr) && netErr.Timeout() {
-			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		if errors.Is(err, http.ErrHandlerTimeout) || errors.Is(err, context.DeadlineExceeded) {
-			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 		}
 		// Check for DNS lookup and network errors after timeout checks
 		var opErr *net.OpError
@@ -795,9 +774,9 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
 	}
 
 	// Convert to Bifrost response
-	response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, key.BlacklistedModels, request.Unfiltered)
+	response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	if response == nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil)
 	}
 
 	response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
@@ -838,24 +817,17 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockTextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	path, deployment := provider.getModelPath("invoke", request.Model, key)
+	path := provider.getModelPath("invoke", request.Model, key)
 	body, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonData, path, key)
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
@@ -867,29 +839,25 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 	// Handle model-specific response conversion
 	var bifrostResponse *schemas.BifrostTextCompletionResponse
 	switch {
-	case schemas.IsAnthropicModel(deployment):
+	case schemas.IsAnthropicModel(request.Model):
 		var response BedrockAnthropicTextResponse
 		if err := sonic.Unmarshal(body, &response); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing anthropic response", err)
 		}
 		bifrostResponse = response.ToBifrostTextCompletionResponse()
-	case schemas.IsMistralModel(deployment):
+	case schemas.IsMistralModel(request.Model):
 		var response BedrockMistralTextResponse
 		if err := sonic.Unmarshal(body, &response); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing mistral response", err)
 		}
 		bifrostResponse = response.ToBifrostTextCompletionResponse()
 	default:
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model), providerName)
+		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model))
 	}
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.ModelDeployment = deployment
-	bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -902,7 +870,7 @@ func (provider *BedrockProvider) TextCompletion(ctx *schemas.BifrostContext, key
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		var rawResponse interface{}
 		if err := sonic.Unmarshal(body, &rawResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("error parsing raw response", err)
 		}
 		bifrostResponse.ExtraFields.RawResponse = rawResponse
 	}
@@ -920,22 +888,17 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 	providerName := provider.GetProviderKey()
 
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockTextCompletionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream")
+	resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "invoke-with-response-stream")
 	if bifrostErr != nil {
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -949,11 +912,12 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -999,14 +963,9 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 						Message: schemas.ErrProviderNetworkError,
 						Error:   err,
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.TextCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}, responseChan, provider.logger)
 			} else {
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 			}
 			return
 		}
@@ -1037,15 +996,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 					Error: &schemas.ErrorField{
 						Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg),
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.TextCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}, responseChan, provider.logger)
 			} else {
 				err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 			}
 			return
 		}
@@ -1057,18 +1011,14 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
 			}
 			if err := sonic.Unmarshal(message.Payload, &chunkPayload); err != nil {
 				provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
 
 			// Create BifrostStreamChunk response containing the raw model-specific JSON chunk
 			textResponse := &schemas.BifrostTextCompletionResponse{
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:     schemas.TextCompletionStreamRequest,
-					Provider:        providerName,
-					ModelRequested:  request.Model,
-					ModelDeployment: deployment,
-					Latency:         time.Since(startTime).Milliseconds(),
+					Latency: time.Since(startTime).Milliseconds(),
 					// Pass the raw JSON string from the chunk bytes
 					RawResponse: string(chunkPayload.Bytes),
 				},
@@ -1090,26 +1040,19 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
-	if !ensureBedrockKeyConfig(&key) {
-		return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName)
-	}
-
 	// Use centralized Bedrock converter
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockChatCompletionRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Format the path with proper model identifier
-	path, deployment := provider.getModelPath("converse", request.Model, key)
+	path := provider.getModelPath("converse", request.Model, key)
 
 	// Create the signed request
 	responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key)
@@ -1126,13 +1069,13 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	// Parse the response using the new Bedrock type
 	if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Convert using the new response converter
 	bifrostResponse, err := bedrockResponse.ToBifrostChatResponse(ctx, request.Model)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Override finish reason for structured output
@@ -1146,10 +1089,6 @@ func (provider *BedrockProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	}
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.ModelDeployment = deployment
-	bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -1176,21 +1115,17 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 	if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
 		return nil, err
 	}
-
-	providerName := provider.GetProviderKey()
-
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToBedrockChatCompletionRequest(ctx, request)
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
+	resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream")
 	if bifrostErr != nil {
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1201,14 +1136,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 	responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
 	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
-
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1264,7 +1199,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 				break
 			}
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-			provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err)
+			provider.logger.Warn("Error decoding EventStream message: %v", err)
 			// Transport-level errors (stale/closed connection, unexpected EOF) are retryable.
 			// Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry.
 			if isStreamTransportError(err) {
@@ -1274,14 +1209,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 						Message: schemas.ErrProviderNetworkError,
 						Error:   err,
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.ChatCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}, responseChan, provider.logger)
 			} else {
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 			}
 			return
 		}
@@ -1297,7 +1227,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 				}
 			}
 			errMsg := string(message.Payload)
-			err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
+			err := fmt.Errorf("stream %s: %s", excType, errMsg)
 			// Retryable AWS exceptions must not set IsBifrostError:true — that would
 			// bypass the retry gate in executeRequestWithRetries. Instead emit
 			// IsBifrostError:false with the equivalent HTTP status code so the existing
@@ -1309,14 +1239,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 					Error: &schemas.ErrorField{
 						Message: err.Error(),
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						RequestType:    schemas.ChatCompletionStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}, responseChan, provider.logger)
 			} else {
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 			}
 			return
 		}
@@ -1326,7 +1251,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 			var streamEvent BedrockStreamEvent
 			if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil {
 				provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
@@ -1405,12 +1330,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 					},
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:     schemas.ChatCompletionStreamRequest,
-					Provider:        providerName,
-					ModelRequested:  request.Model,
-					ModelDeployment: deployment,
-					ChunkIndex:      chunkIndex,
-					Latency:         time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency:    time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
 			chunkIndex++
@@ -1427,11 +1348,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
 			response, bifrostErr, _ := streamEvent.ToBifrostChatCompletionStream(streamState)
 			if bifrostErr != nil {
-
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1440,12 +1356,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex response.ID = id response.Model = request.Model response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1464,8 +1376,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex } // Send final response - response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) - response.ExtraFields.ModelDeployment = deployment + response := providerUtils.CreateBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, request.Model, 0) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonData) @@ -1486,26 +1397,19 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Use centralized Bedrock converter jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Format the path with proper model identifier - path, deployment := provider.getModelPath("converse", request.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -1522,22 +1426,18 @@ func (provider *BedrockProvider) Responses(ctx *schemas.BifrostContext, key sche // Parse the response using the new Bedrock type if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert using the new response converter bifrostResponse, err := bedrockResponse.ToBifrostResponsesResponse(ctx) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err, providerName), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to convert bedrock response", err), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.Model = deployment + bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - 
bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1565,20 +1465,17 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po return nil, err } - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockResponsesRequest(ctx, request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - resp, deployment, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") + resp, bifrostErr := provider.makeStreamingRequest(ctx, jsonData, key, request.Model, "converse-stream") if bifrostErr != nil { return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1592,11 +1489,12 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1617,7 +1515,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx 
*schemas.BifrostContext, po // Create stream state for stateful conversions streamState := acquireBedrockResponsesStreamState() - streamState.Model = &deployment + streamState.Model = &request.Model defer releaseBedrockResponsesStreamState(streamState) // Check for structured output mode - if set, we need to intercept tool calls @@ -1633,7 +1531,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po lastChunkTime := startTime decoder := eventstream.NewDecoder() payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer - for { // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { @@ -1651,12 +1548,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po finalResponses := FinalizeBedrockStream(streamState, chunkIndex, usage) for i, finalResponse := range finalResponses { finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1679,7 +1572,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po break } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error decoding %s EventStream message: %v", providerName, err) + provider.logger.Warn("Error decoding EventStream message: %v", err) // Transport-level errors (stale/closed connection, unexpected EOF) are retryable. // Use IsBifrostError:false so the retry gate in executeRequestWithRetries can retry. 
if isStreamTransportError(err) { @@ -1689,14 +1582,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1712,7 +1600,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po } } errMsg := string(message.Payload) - err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) + err := fmt.Errorf("stream %s: %s", excType, errMsg) // Retryable AWS exceptions must not set IsBifrostError:true — that would // bypass the retry gate in executeRequestWithRetries. 
Instead emit // IsBifrostError:false with the equivalent HTTP status code so the existing @@ -1724,14 +1612,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Error: &schemas.ErrorField{ Message: err.Error(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1741,7 +1624,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po var streamEvent BedrockStreamEvent if err := sonic.Unmarshal(message.Payload, &streamEvent); err != nil { provider.logger.Debug("Failed to parse JSON from event buffer: %v, data: %s", err, string(message.Payload)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) return } @@ -1797,12 +1680,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po SequenceNumber: chunkIndex, Delta: &content, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } chunkIndex++ @@ -1819,11 +1698,6 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po responses, bifrostErr, _ := streamEvent.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr 
!= nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1831,12 +1705,8 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po for _, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ModelDeployment: deployment, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } chunkIndex++ lastChunkTime = time.Now() @@ -1862,15 +1732,10 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Determine model type modelType, err := DetermineEmbeddingModelType(request.Model) if err != nil { - return nil, providerUtils.NewConfigurationError(err.Error(), providerName) + return nil, providerUtils.NewConfigurationError(err.Error()) } // Convert request and execute based on model type @@ -1879,7 +1744,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string var jsonData []byte switch modelType { @@ -1889,12 +1753,11 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockTitanEmbeddingRequest(request) - }, - 
provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) case "cohere": @@ -1903,16 +1766,15 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockCohereEmbeddingRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) + path = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) default: - return nil, providerUtils.NewConfigurationError("unsupported embedding model type", providerName) + return nil, providerUtils.NewConfigurationError("unsupported embedding model type") } if providerResponseHeaders != nil { @@ -1921,32 +1783,40 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche if bifrostError != nil { return nil, providerUtils.EnrichError(ctx, bifrostError, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - // Parse response based on model type var bifrostResponse *schemas.BifrostEmbeddingResponse switch modelType { case "titan": var titanResp BedrockTitanEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Titan embedding response", err), 
jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = titanResp.ToBifrostEmbeddingResponse() bifrostResponse.Model = request.Model case "cohere": - var cohereResp cohere.CohereEmbeddingResponse + var cohereResp BedrockCohereEmbeddingResponse if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + } + converted, convErr := cohereResp.ToBifrostEmbeddingResponse() + if convErr != nil { + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse = cohereResp.ToBifrostEmbeddingResponse() + bifrostResponse = converted bifrostResponse.Model = request.Model + // For embeddings_by_type responses preserve the raw Bedrock payload so the + // invoke-endpoint converter can return all encoding variants verbatim, since + // the internal BifrostEmbeddingResponse only has float32 and string fields. 
+ if cohereResp.ResponseType == "embeddings_by_type" { + var rawResponseData interface{} + if err := sonic.Unmarshal(rawResponse, &rawResponseData); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponseData + } + } } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1972,26 +1842,16 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - - deployment := strings.TrimSpace(resolveBedrockDeployment(request.Model, key)) - if deployment == "" { - return nil, providerUtils.NewConfigurationError("bedrock rerank model is empty", providerName) - } - if !strings.HasPrefix(deployment, "arn:") { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", deployment), providerName) + if !strings.HasPrefix(request.Model, "arn:") { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("bedrock rerank requires an ARN model identifier; got %q", request.Model)) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToBedrockRerankRequest(request, deployment) + return ToBedrockRerankRequest(request, request.Model) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2015,10 +1875,6 @@ func (provider *BedrockProvider) Rerank(ctx *schemas.BifrostContext, key schemas bifrostResponse := 
response.ToBifrostRerankResponse(request.Documents, returnDocuments) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2053,37 +1909,34 @@ func (provider *BedrockProvider) TranscriptionStream(ctx *schemas.BifrostContext } // ImageGeneration generates images using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2, and Stability AI models. // Returns a BifrostImageGenerationResponse containing the generated images and any error that occurred. func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var rawResponse []byte var jsonData []byte var bifrostError *schemas.BifrostError var latency time.Duration var providerResponseHeaders map[string]string var path string - var deployment string + + path = provider.getModelPath("invoke", request.Model, key) jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageGenerationRequest(request) 
+ } return ToBedrockImageGenerationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } - path, deployment = provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError = provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2096,19 +1949,15 @@ func (provider *BedrockProvider) ImageGeneration(ctx *schemas.BifrostContext, ke var bifrostResponse *schemas.BifrostImageGenerationResponse var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image generation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = 
latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2133,33 +1982,36 @@ func (provider *BedrockProvider) ImageGenerationStream(ctx *schemas.BifrostConte } // ImageEdit performs image editing using Amazon Bedrock. -// Supports Titan Image Generator v1, Nova Canvas v1, and Titan Image Generator v2. -// Supports three edit types: INPAINTING, OUTPAINTING, and BACKGROUND_REMOVAL. +// Supports Titan Image Generator v1, Nova Canvas v1, Titan Image Generator v2 (three edit types: +// INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL), and Stability AI edit models (inpaint, outpaint, +// recolor, search-replace, erase-object, remove-bg, control-sketch, control-structure, style-guide, +// style-transfer, upscale-creative, upscale-conservative, upscale-fast). // Returns a BifrostImageGenerationResponse containing the edited images and any error that occurred. func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ImageEditRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError + // Stability AI routing and task-type inference use the actual model ID. 
+ path := provider.getModelPath("invoke", request.Model, key) + jsonData, bifrostError = providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageEditRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { + if isStabilityAIModel(request.Model) { + return ToStabilityAIImageEditRequest(request, request.Model) + } + return ToBedrockImageEditRequest(request) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2171,20 +2023,16 @@ func (provider *BedrockProvider) ImageEdit(ctx *schemas.BifrostContext, key sche // Parse response (reuse BedrockImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image edit response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, 
provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2216,11 +2064,6 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key return nil, err } - providerName := provider.GetProviderKey() - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - var jsonData []byte var bifrostError *schemas.BifrostError @@ -2229,14 +2072,13 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToBedrockImageVariationRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostError != nil { return nil, bifrostError } // Make API request (same URL as image generation) - path, deployment := provider.getModelPath("invoke", request.Model, key) + path := provider.getModelPath("invoke", request.Model, key) rawResponse, latency, providerResponseHeaders, bifrostError := provider.completeRequest(ctx, jsonData, path, key) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -2248,20 +2090,16 @@ func (provider *BedrockProvider) ImageVariation(ctx *schemas.BifrostContext, key // Parse response (reuse BedrockImageGenerationResponse and ToBifrostImageGenerationResponse) var imageResp BedrockImageGenerationResponse if err := sonic.Unmarshal(rawResponse, &imageResp); err != nil { - 
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing image variation response", err), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } if imageResp.Error != "" { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(imageResp.Error, nil), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Convert response and set metadata bifrostResponse := ToBifrostImageGenerationResponse(&imageResp) bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageVariationRequest - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.ModelDeployment = deployment bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -2320,13 +2158,6 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is is missing in file upload request") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2348,7 +2179,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch if s3Bucket == "" { 
provider.logger.Error("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)") - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } // Parse bucket name and optional prefix from s3Bucket (could be "bucket-name" or "s3://bucket-name/prefix/") @@ -2382,14 +2213,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch httpReq, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(request.File)) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } httpReq.Header.Set("Content-Type", "application/octet-stream") httpReq.ContentLength = int64(len(request.File)) // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("error signing request: %s", err.Error.Message) return nil, err } @@ -2409,14 +2240,14 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { body, _ := io.ReadAll(resp.Body) provider.logger.Error("s3 upload failed: %d", resp.StatusCode) - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 upload failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Return S3 URI as the file ID @@ -2433,9 +2264,7 @@ func (provider *BedrockProvider) FileUpload(ctx *schemas.BifrostContext, key sch StorageBackend: schemas.FileStorageS3, StorageURI: s3URI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2448,8 +2277,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc return nil, err } - providerName := provider.GetProviderKey() - // Get S3 bucket from storage config or extra params s3Bucket := "" s3Prefix := "" @@ -2471,7 +2298,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc } if s3Bucket == "" { - return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("s3_bucket is required for Bedrock file operations (provide in storage_config.s3 or extra_params)", nil) } bucketName, bucketPrefix := parseS3URI(s3Bucket) @@ -2482,7 +2309,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -2493,10 +2320,6 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -2523,14 +2346,11 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request for S3 - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); bifrostErr != nil { return nil, bifrostErr } @@ -2549,23 +2369,23 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) 
resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { - return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + return nil, providerUtils.NewProviderAPIError(fmt.Sprintf("S3 list failed: %s", string(body)), nil, resp.StatusCode, nil, nil) } // Parse S3 ListObjectsV2 XML response var listResp S3ListObjectsResponse if err := parseS3ListResponse(body, &listResp); err != nil { - return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error parsing S3 response", err) } // Convert files to Bifrost format @@ -2597,9 +2417,7 @@ func (provider *BedrockProvider) FileList(ctx *schemas.BifrostContext, keys []sc Data: files, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -2615,25 +2433,18 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for 
_, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2645,12 +2456,12 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodHead, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2670,13 +2481,13 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 HEAD failed with status %d", resp.StatusCode), nil, resp.StatusCode, nil, nil) continue 
} @@ -2706,9 +2517,7 @@ func (provider *BedrockProvider) FileRetrieve(ctx *schemas.BifrostContext, keys StorageBackend: schemas.FileStorageS3, StorageURI: request.FileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2722,25 +2531,18 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2752,12 +2554,12 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, 
key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2777,7 +2579,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } @@ -2785,7 +2587,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 DELETE failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } @@ -2796,9 +2598,7 @@ func (provider *BedrockProvider) FileDelete(ctx *schemas.BifrostContext, keys [] Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2812,25 +2612,18 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id (S3 URI) 
is required", nil) } // Parse S3 URI bucketName, s3Key := parseS3URI(request.FileID) if bucketName == "" || s3Key == "" { - return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid S3 URI format, expected s3://bucket/key", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -2842,12 +2635,12 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { lastErr = err continue } @@ -2867,21 +2660,21 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) resp.Body.Close() - lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, providerName, nil, nil) + lastErr = providerUtils.NewProviderAPIError(fmt.Sprintf("S3 GET failed: %s", string(body)), nil, resp.StatusCode, nil, nil) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading S3 object content", err) continue } @@ -2895,9 +2688,7 @@ func (provider *BedrockProvider) FileContent(ctx *schemas.BifrostContext, keys [ Content: body, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -2912,13 +2703,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - provider.logger.Error("bedrock key config is not provided") - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Require RoleArn in extra params roleArn := "" // First we will honor the role_arn coming from the client side if present @@ -2929,14 +2713,14 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // If its empty then we will honor the role_arn from the key config if roleArn == "" { - if key.BedrockKeyConfig.ARN != nil { - roleArn = key.BedrockKeyConfig.ARN.GetValue() + if key.BedrockKeyConfig.RoleARN != nil { + roleArn = key.BedrockKeyConfig.RoleARN.GetValue() } } // And if still we don't get role ARN if roleArn 
== "" { provider.logger.Error("role_arn is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("role_arn is required for Bedrock batch API (provide in extra_params)", nil) } // Get output S3 URI from extra params outputS3Uri := "" @@ -2947,24 +2731,12 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } if outputS3Uri == "" { provider.logger.Error("output_s3_uri is required for Bedrock batch API (provide in extra_params)") - return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("output_s3_uri is required for Bedrock batch API (provide in extra_params)", nil) } if request.Model == nil { provider.logger.Error("model is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil, providerName) - } - - // Get model ID - - var modelID *string - if key.BedrockKeyConfig.Deployments != nil && request.Model != nil { - if deployment, ok := key.BedrockKeyConfig.Deployments[*request.Model]; ok { - modelID = schemas.Ptr(deployment) - } - } - if modelID == nil { - modelID = request.Model + return nil, providerUtils.NewBifrostOperationError("model is required for Bedrock batch API", nil) } // Generate job name @@ -2992,9 +2764,9 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // Convert inline requests to Bedrock JSONL format - jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, modelID) + jsonlData, err := ConvertBedrockRequestsToJSONL(request.Requests, request.Model) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, 
providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err) } // Generate S3 key for the input file @@ -3014,7 +2786,6 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc bucket, s3Key, jsonlData, - providerName, ); bifrostErr != nil { return nil, bifrostErr } @@ -3025,13 +2796,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc // Validate that we have an input file ID (either provided or uploaded) if inputFileID == "" { provider.logger.Error("either input_file_id (S3 URI) or requests array is required for Bedrock batch API") - return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("either input_file_id (S3 URI) or requests array is required for Bedrock batch API", nil) } // Build request bedrockReq := &BedrockBatchJobRequest{ JobName: jobName, - ModelID: modelID, + ModelID: request.Model, RoleArn: roleArn, InputDataConfig: BedrockInputDataConfig{ S3InputDataConfig: BedrockS3InputDataConfig{ @@ -3056,7 +2827,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(bedrockReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } sendBackRawRequest := provider.sendBackRawRequest @@ -3071,11 +2842,11 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc reqURL := fmt.Sprintf("https://bedrock.%s.amazonaws.com/model-invocation-job", region) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewBuffer(jsonData)) if err != nil { - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError("error creating request", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error creating request", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } @@ -3094,13 +2865,13 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc }, }, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error reading response", err), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } if 
resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { @@ -3109,7 +2880,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, sendBackRawRequest, sendBackRawResponse) } // AWS CreateModelInvocationJob only returns jobArn, not status or other details. @@ -3126,9 +2897,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc InputFileID: inputFileID, Status: schemas.BatchStatusValidating, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3141,9 +2910,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc Status: retrieveResp.Status, CreatedAt: retrieveResp.CreatedAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3161,12 +2928,10 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - // Initialize serial pagination helper (Bedrock uses PageToken for pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid 
pagination cursor", err) } // Get current key to query @@ -3177,17 +2942,9 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Object: "list", Data: []schemas.BifrostBatchRetrieveResponse{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - }, }, nil } - if !ensureBedrockKeyConfig(&key) { - return nil, providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3210,11 +2967,11 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error creating request", err) } // Sign request - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); bifrostErr != nil { return nil, bifrostErr } @@ -3233,13 +2990,13 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error reading response", err, providerName) + return nil, providerUtils.NewBifrostOperationError("error reading response", err) } if resp.StatusCode != http.StatusOK { @@ -3248,7 +3005,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s var bedrockResp BedrockBatchJobListResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert batches to Bifrost format @@ -3293,9 +3050,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s Data: batches, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -3335,7 +3090,7 @@ func (provider *BedrockProvider) fetchBatchManifest(ctx *schemas.BifrostContext, } // Sign request for S3 - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3", provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, "s3"); err != nil { provider.logger.Error("failed to sign manifest request: %v", err) return nil } @@ -3373,19 +3128,12 @@ func (provider *BedrockProvider) 
BatchRetrieve(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3397,12 +3145,12 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3422,14 +3170,14 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3440,7 +3188,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys var bedrockResp BedrockBatchJobResponse if err := sonic.Unmarshal(body, &bedrockResp); err != nil { - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) continue } @@ -3459,9 +3207,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys Status: ToBifrostBatchStatus(bedrockResp.Status), Metadata: metadata, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3517,19 +3263,12 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id (job ARN) is required", nil) } var lastErr *schemas.BifrostError for _, key := range keys { - if !ensureBedrockKeyConfig(&key) { - lastErr = providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - continue - } - region := DefaultBedrockRegion if key.BedrockKeyConfig.Region != nil && key.BedrockKeyConfig.Region.GetValue() != "" { region = key.BedrockKeyConfig.Region.GetValue() @@ -3541,12 +3280,12 @@ func (provider *BedrockProvider) BatchCancel(ctx 
*schemas.BifrostContext, keys [ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error creating request", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error creating request", err) continue } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3566,14 +3305,14 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ }, } } - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if err != nil { - lastErr = providerUtils.NewBifrostOperationError("error reading response", err, providerName) + lastErr = providerUtils.NewBifrostOperationError("error reading response", err) continue } @@ -3596,9 +3335,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", Status: schemas.BatchStatusCancelling, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: totalLatency.Milliseconds(), + Latency: totalLatency.Milliseconds(), }, }, nil } @@ -3608,9 +3345,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ Object: "batch", 
Status: retrieveResp.Status, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCancelRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -3631,8 +3366,6 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return nil, err } - providerName := provider.GetProviderKey() - // First, retrieve the batch to get the output S3 URI prefix (using all keys) batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{ Provider: request.Provider, @@ -3643,7 +3376,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys } if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" { - return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch results not available: output S3 URI is empty (batch may not be completed)", nil) } outputS3URI := *batchResp.OutputFileID @@ -3685,7 +3418,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys if directErr != nil { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to access batch results at %s: listing failed and direct access failed", outputS3URI), - nil, providerName) + nil) } // Direct download succeeded, parse the content @@ -3694,9 +3427,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: results, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: fileContentResp.ExtraFields.Latency, + Latency: fileContentResp.ExtraFields.Latency, }, } if len(parseErrors) > 0 { @@ -3729,9 +3460,7 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys BatchID: request.BatchID, Results: 
allResults, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchResultsRequest, - Provider: providerName, - Latency: totalLatency, + Latency: totalLatency, }, } @@ -3742,26 +3471,14 @@ func (provider *BedrockProvider) BatchResults(ctx *schemas.BifrostContext, keys return batchResultsResp, nil } -func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) (string, string) { - deployment := resolveBedrockDeployment(model, key) - // Default: use model/deployment directly - path := fmt.Sprintf("%s/%s", deployment, basePath) +func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string { + path := fmt.Sprintf("%s/%s", model, basePath) // If ARN is present, Bedrock expects the ARN-scoped identifier if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.ARN != nil && key.BedrockKeyConfig.ARN.GetValue() != "" { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), deployment)) + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", key.BedrockKeyConfig.ARN.GetValue(), model)) path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath) } - return path, deployment -} - -func resolveBedrockDeployment(model string, key schemas.Key) string { - deployment := model - if key.BedrockKeyConfig != nil && key.BedrockKeyConfig.Deployments != nil { - if mapped, ok := key.BedrockKeyConfig.Deployments[model]; ok && mapped != "" { - deployment = mapped - } - } - return deployment + return path } func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { @@ -3769,16 +3486,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc return nil, err } - providerName := provider.GetProviderKey() - - if !ensureBedrockKeyConfig(&key) { - return nil, 
providerUtils.NewConfigurationError("bedrock key config is not provided", providerName) - } - // Convert to Bedrock Converse format using the existing responses converter converseReq, convErr := ToBedrockResponsesRequest(ctx, request) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, convErr) } // Wrap in the CountTokens request envelope @@ -3787,11 +3498,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc jsonData, err := providerUtils.MarshalSorted(countTokensReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Format the path with proper model identifier - path, deployment := provider.getModelPath("count-tokens", request.Model, key) + path := provider.getModelPath("count-tokens", request.Model, key) // Send the request responseBody, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, jsonData, path, key) @@ -3802,15 +3513,11 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc if isCountTokensUnsupported(bifrostErr) { estimated := estimateTokenCount(jsonData) return &schemas.BifrostCountTokensResponse{ - Model: deployment, + Model: request.Model, InputTokens: estimated, TotalTokens: &estimated, Object: "response.input_tokens", ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - RequestType: schemas.CountTokensRequest, - ModelRequested: request.Model, - ModelDeployment: deployment, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3833,15 +3540,10 @@ func (provider *BedrockProvider) CountTokens(ctx *schemas.BifrostContext, key sc } // Convert to Bifrost format - response := 
bedrockResponse.ToBifrostCountTokensResponse(deployment) + response := bedrockResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.ModelDeployment = deployment response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders - if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { response.ExtraFields.RawRequest = rawRequest } diff --git a/core/providers/bedrock/bedrock_test.go b/core/providers/bedrock/bedrock_test.go index adc2cdf937..4e1c659cd5 100644 --- a/core/providers/bedrock/bedrock_test.go +++ b/core/providers/bedrock/bedrock_test.go @@ -181,22 +181,22 @@ func TestBedrock(t *testing.T) { {Provider: schemas.Bedrock, Model: "claude-4-sonnet"}, {Provider: schemas.Bedrock, Model: "claude-4.5-sonnet"}, }, - EmbeddingModel: "cohere.embed-v4:0", - RerankModel: rerankModelARN, - ReasoningModel: "claude-4.5-sonnet", - PromptCachingModel: "claude-4.5-sonnet", - ImageEditModel: "amazon.nova-canvas-v1:0", - ImageVariationModel: "amazon.nova-canvas-v1:0", + EmbeddingModel: "cohere.embed-v4:0", + RerankModel: rerankModelARN, + ReasoningModel: "claude-4.5-sonnet", + PromptCachingModel: "claude-4.5-sonnet", + ImageEditModel: "amazon.nova-canvas-v1:0", + ImageVariationModel: "amazon.nova-canvas-v1:0", InterleavedThinkingModel: "global.anthropic.claude-opus-4-5-20251101-v1:0", - BatchExtraParams: batchExtraParams, - FileExtraParams: fileExtraParams, + BatchExtraParams: batchExtraParams, + FileExtraParams: fileExtraParams, Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + 
MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, End2EndToolCalling: true, @@ -228,6 +228,9 @@ func TestBedrock(t *testing.T) { ImageVariation: true, StructuredOutputs: true, InterleavedThinking: true, + EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (per B-header) + // ServerToolsViaOpenAIEndpoint: Bedrock does not support web_search / web_fetch / + // code_execution server tools per Table 20, so no cases would run. Left off. }, } @@ -1262,7 +1265,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "tool-use-123", Name: "get_weather", - Input: json.RawMessage(`{"location":"NYC"}`), + Input: json.RawMessage(`{"location":"NYC"}`), }, }, }, @@ -1337,7 +1340,7 @@ func TestBedrockToBifrostRequestConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "tool-use-456", Name: "calculate", - Input: json.RawMessage(`{"expression":"2+2"}`), + Input: json.RawMessage(`{"expression":"2+2"}`), }, }, }, @@ -1866,7 +1869,7 @@ func TestBifrostToBedrockResponseConversion(t *testing.T) { ToolUse: &bedrock.BedrockToolUse{ ToolUseID: "call-111", Name: "get_weather", - Input: json.RawMessage(`{"location":"NYC"}`), + Input: json.RawMessage(`{"location":"NYC"}`), }, }, { @@ -2228,6 +2231,173 @@ func TestToBedrockResponsesRequest_AdditionalFields_InterfaceSlice(t *testing.T) assert.Equal(t, []string{"/amazon-bedrock-invocationMetrics/inputTokenCount"}, bedrockReq.AdditionalModelResponseFieldPaths) } +func TestToBedrockResponsesRequest_AnthropicTextFormatUsesOutputConfig(t *testing.T) { + schemaObj := any(schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + schemas.KV("properties", schemas.NewOrderedMapFromPairs( + schemas.KV("topic", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + )), + schemas.KV("required", []string{"topic"}), + )) + + req := 
&schemas.BifrostResponsesRequest{ + Model: "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + Params: &schemas.ResponsesParameters{ + Text: &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "json_schema", + Name: schemas.Ptr("classification"), + JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{ + Schema: &schemaObj, + }, + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bedrockReq, err := bedrock.ToBedrockResponsesRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, bedrockReq) + require.NotNil(t, bedrockReq.AdditionalModelRequestFields, "expected additional model request fields for anthropic responses structured output") + + outputConfigRaw, hasOutputConfig := bedrockReq.AdditionalModelRequestFields.Get("output_config") + require.True(t, hasOutputConfig, "expected output_config for anthropic responses structured output") + + outputConfig, ok := schemas.SafeExtractOrderedMap(outputConfigRaw) + require.True(t, ok, "expected output_config to be an ordered map") + + formatRaw, hasFormat := outputConfig.Get("format") + require.True(t, hasFormat, "expected output_config.format") + + formatMap, ok := schemas.SafeExtractOrderedMap(formatRaw) + require.True(t, ok, "expected output_config.format to be an ordered map") + + formatType, ok := formatMap.Get("type") + require.True(t, ok, "expected output_config.format.type") + assert.Equal(t, "json_schema", formatType) + + schemaRaw, ok := formatMap.Get("schema") + require.True(t, ok, "expected output_config.format.schema") + schemaMap, ok := schemas.SafeExtractOrderedMap(schemaRaw) + require.True(t, ok, "expected output_config.format.schema to remain ordered") + require.NotNil(t, schemaMap) + + if bedrockReq.ToolConfig != nil { + assert.Nil(t, bedrockReq.ToolConfig.ToolChoice, "expected no forced tool choice for anthropic responses structured output") + assert.Empty(t, bedrockReq.ToolConfig.Tools, "expected no synthetic 
structured output tool for anthropic responses structured output") + } +} + +func TestToBedrockResponsesRequest_NonAnthropicTextFormatStillUsesToolConversion(t *testing.T) { + schemaObj := any(schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + schemas.KV("properties", schemas.NewOrderedMapFromPairs( + schemas.KV("topic", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + )), + schemas.KV("required", []string{"topic"}), + )) + + req := &schemas.BifrostResponsesRequest{ + Model: "bedrock/amazon.nova-pro-v1:0", + Params: &schemas.ResponsesParameters{ + Text: &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "json_schema", + Name: schemas.Ptr("classification"), + JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{ + Schema: &schemaObj, + }, + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bedrockReq, err := bedrock.ToBedrockResponsesRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, bedrockReq) + + if bedrockReq.AdditionalModelRequestFields != nil { + _, hasOutputConfig := bedrockReq.AdditionalModelRequestFields.Get("output_config") + assert.False(t, hasOutputConfig, "expected no output_config for non-anthropic responses structured output") + } + + require.NotNil(t, bedrockReq.ToolConfig, "expected tool_config for non-anthropic responses structured output") + require.NotEmpty(t, bedrockReq.ToolConfig.Tools, "expected synthetic structured output tool to be added") + require.NotNil(t, bedrockReq.ToolConfig.ToolChoice, "expected structured output tool choice to be forced") + require.NotNil(t, bedrockReq.ToolConfig.ToolChoice.Tool, "expected structured output tool choice to target the synthetic tool") + assert.Contains(t, bedrockReq.ToolConfig.ToolChoice.Tool.Name, "bf_so_", "expected forced tool choice to target the synthetic structured output tool") +} + +func 
TestToBedrockResponsesRequest_NonAnthropicTextFormatPreservedWithUserTools(t *testing.T) { + schemaObj := any(schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + schemas.KV("properties", schemas.NewOrderedMapFromPairs( + schemas.KV("topic", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + )), + schemas.KV("required", []string{"topic"}), + )) + + toolParams := schemas.ToolFunctionParameters{ + Type: "object", + Properties: schemas.NewOrderedMapFromPairs( + schemas.KV("city", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + ), + } + + req := &schemas.BifrostResponsesRequest{ + Model: "bedrock/amazon.nova-pro-v1:0", + Params: &schemas.ResponsesParameters{ + Text: &schemas.ResponsesTextConfig{ + Format: &schemas.ResponsesTextConfigFormat{ + Type: "json_schema", + Name: schemas.Ptr("classification"), + JSONSchema: &schemas.ResponsesTextConfigFormatJSONSchema{ + Schema: &schemaObj, + }, + }, + }, + Tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("get_weather"), + Description: schemas.Ptr("Get weather information"), + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Parameters: &toolParams, + }, + }, + }, + ToolChoice: &schemas.ResponsesToolChoice{ + ResponsesToolChoiceStruct: &schemas.ResponsesToolChoiceStruct{ + Type: schemas.ResponsesToolChoiceTypeFunction, + Name: schemas.Ptr("get_weather"), + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bedrockReq, err := bedrock.ToBedrockResponsesRequest(ctx, req) + require.NoError(t, err) + require.NotNil(t, bedrockReq) + require.NotNil(t, bedrockReq.ToolConfig, "expected tool_config to be initialized") + require.Len(t, bedrockReq.ToolConfig.Tools, 2, "expected synthetic structured output tool plus user tool") + require.NotNil(t, bedrockReq.ToolConfig.ToolChoice, "expected structured output tool choice to be forced") + require.NotNil(t, 
bedrockReq.ToolConfig.ToolChoice.Tool, "expected structured output tool choice to target the synthetic tool") + assert.Equal(t, "bf_so_classification", bedrockReq.ToolConfig.ToolChoice.Tool.Name) + assert.Equal(t, "bf_so_classification", bedrockReq.ToolConfig.Tools[0].ToolSpec.Name) + assert.Equal(t, "get_weather", bedrockReq.ToolConfig.Tools[1].ToolSpec.Name) +} + // TestToolResultJSONParsingResponsesAPI tests that tool results are correctly parsed and wrapped based on JSON type // Tests only Responses API. func TestToolResultJSONParsingResponsesAPI(t *testing.T) { @@ -2254,7 +2424,7 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name: "JSONObjectResult", toolResultContent: `{"location":"NYC","temperature":72}`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"location": "NYC", "temperature": float64(72)}), + expectedJSON: mustMarshalJSON(map[string]any{"location": "NYC", "temperature": float64(72)}), }, { name: "JSONArrayResult", @@ -2271,37 +2441,37 @@ func TestToolResultJSONParsingResponsesAPI(t *testing.T) { name: "JSONPrimitiveNumberResult", toolResultContent: `42`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": float64(42)}), + expectedJSON: mustMarshalJSON(map[string]any{"value": float64(42)}), }, { name: "JSONPrimitiveStringResult", toolResultContent: `"hello world"`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": "hello world"}), + expectedJSON: mustMarshalJSON(map[string]any{"value": "hello world"}), }, { name: "JSONPrimitiveBooleanResult", toolResultContent: `true`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": true}), + expectedJSON: mustMarshalJSON(map[string]any{"value": true}), }, { name: "JSONPrimitiveNullResult", toolResultContent: `null`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"value": nil}), + expectedJSON: mustMarshalJSON(map[string]any{"value": nil}), 
}, { name: "EmptyJSONObjectResult", toolResultContent: `{}`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{}), + expectedJSON: mustMarshalJSON(map[string]any{}), }, { name: "EmptyJSONArrayResult", toolResultContent: `[]`, expectedContentType: "json", - expectedJSON: mustMarshalJSON(map[string]any{"results": []any{}}), + expectedJSON: mustMarshalJSON(map[string]any{"results": []any{}}), }, } @@ -2893,6 +3063,379 @@ func TestAnthropicReasoningConfigUsesThinkingField(t *testing.T) { } } +func TestAnthropicOrderedOutputConfigRoundTripsReasoning(t *testing.T) { + request := &bedrock.BedrockConverseRequest{ + ModelID: "anthropic.claude-opus-4-6-v1", + Messages: []bedrock.BedrockMessage{ + { + Role: bedrock.BedrockMessageRoleUser, + Content: []bedrock.BedrockContentBlock{ + { + Text: schemas.Ptr("Hello"), + }, + }, + }, + }, + AdditionalModelRequestFields: schemas.NewOrderedMapFromPairs( + schemas.KV("thinking", map[string]any{ + "type": "adaptive", + "budget_tokens": 2048, + }), + schemas.KV("output_config", schemas.NewOrderedMapFromPairs( + schemas.KV("effort", "medium"), + )), + ), + ExtraParams: map[string]any{ + "reasoning_summary": "auto", + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := request.ToBifrostResponsesRequest(ctx) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Params) + require.NotNil(t, result.Params.Reasoning) + require.NotNil(t, result.Params.Reasoning.Effort) + assert.Equal(t, "medium", *result.Params.Reasoning.Effort) + require.NotNil(t, result.Params.Reasoning.MaxTokens) + assert.Equal(t, 2048, *result.Params.Reasoning.MaxTokens) + require.NotNil(t, result.Params.Reasoning.Summary) + assert.Equal(t, "auto", *result.Params.Reasoning.Summary) +} + +func TestAnthropicOutputConfigFormatStillFallsBackToBudgetTokensForReasoning(t *testing.T) { + request := &bedrock.BedrockConverseRequest{ + ModelID: "anthropic.claude-opus-4-6-v1", + 
Messages: []bedrock.BedrockMessage{ + { + Role: bedrock.BedrockMessageRoleUser, + Content: []bedrock.BedrockContentBlock{ + { + Text: schemas.Ptr("Hello"), + }, + }, + }, + }, + AdditionalModelRequestFields: schemas.NewOrderedMapFromPairs( + schemas.KV("thinking", map[string]any{ + "type": "adaptive", + "budget_tokens": 2048, + }), + schemas.KV("output_config", schemas.NewOrderedMapFromPairs( + schemas.KV("format", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "json_schema"), + schemas.KV("schema", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + )), + )), + )), + ), + ExtraParams: map[string]any{ + "reasoning_summary": "auto", + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := request.ToBifrostResponsesRequest(ctx) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Params) + require.NotNil(t, result.Params.Reasoning) + require.NotNil(t, result.Params.Reasoning.Effort) + // Effort is inferred from budget_tokens (2048) against the model-specific max output tokens + // (128K for claude-opus-4-6) minus Anthropic's minimum reasoning budget (1024). That ratio + // (~0.008) falls in the "low" bucket — see providerUtils.GetReasoningEffortFromBudgetTokens. + assert.Equal(t, "low", *result.Params.Reasoning.Effort) + require.NotNil(t, result.Params.Reasoning.MaxTokens) + assert.Equal(t, 2048, *result.Params.Reasoning.MaxTokens) + require.NotNil(t, result.Params.Reasoning.Summary) + assert.Equal(t, "auto", *result.Params.Reasoning.Summary) +} + +// TestAnthropicStructuredOutputUsesOutputConfigWithoutForcedToolChoice ensures +// Anthropic Bedrock structured output uses native output_config.format and does +// not synthesize a forced tool choice, while keeping reasoning (thinking) active. 
+func TestAnthropicStructuredOutputUsesOutputConfigWithoutForcedToolChoice(t *testing.T) { + responseFormat := any(map[string]any{ + "type": "json_schema", + "json_schema": map[string]any{ + "name": "classification", + "schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "topic": map[string]any{ + "type": "string", + }, + }, + "required": []any{"topic"}, + }, + }, + }) + + bifrostReq := &schemas.BifrostChatRequest{ + Model: "anthropic.claude-3-7-sonnet-v1", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Classify this"), + }, + }, + }, + Params: &schemas.ChatParameters{ + ResponseFormat: &responseFormat, + Reasoning: &schemas.ChatReasoning{ + MaxTokens: schemas.Ptr(2048), + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := bedrock.ToBedrockChatCompletionRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.AdditionalModelRequestFields) + + outputConfigRaw, hasOutputConfig := result.AdditionalModelRequestFields.Get("output_config") + require.True(t, hasOutputConfig, "expected output_config for anthropic structured output") + + outputConfig, ok := outputConfigRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config to be an ordered map") + + formatRaw, hasFormat := outputConfig.Get("format") + require.True(t, hasFormat, "expected output_config.format") + + format, ok := formatRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config.format to be an ordered map") + formatType, hasType := format.Get("type") + require.True(t, hasType, "expected output_config.format.type") + assert.Equal(t, "json_schema", formatType) + _, hasSchema := format.Get("schema") + assert.True(t, hasSchema, "expected output_config.format.schema") + + // reasoning should still be preserved for anthropic + thinkingRaw, hasThinking := 
result.AdditionalModelRequestFields.Get("thinking") + require.True(t, hasThinking, "expected thinking field for anthropic reasoning") + thinking, ok := thinkingRaw.(map[string]any) + require.True(t, ok, "expected thinking to be a map") + assert.Equal(t, "enabled", thinking["type"]) + + // structured output should NOT force tool choice on Bedrock anthropic + if result.ToolConfig != nil { + assert.Nil(t, result.ToolConfig.ToolChoice, "expected no forced tool choice for anthropic structured output") + assert.Empty(t, result.ToolConfig.Tools, "expected no synthetic structured output tool for anthropic structured output") + } +} + +func TestAnthropicStructuredOutputAcceptsOrderedMaps(t *testing.T) { + responseFormat := any(schemas.NewOrderedMapFromPairs( + schemas.KV("type", "json_schema"), + schemas.KV("json_schema", schemas.NewOrderedMapFromPairs( + schemas.KV("name", "classification"), + schemas.KV("schema", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + schemas.KV("description", "Return structured classification"), + schemas.KV("properties", schemas.NewOrderedMapFromPairs( + schemas.KV("topic", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + )), + schemas.KV("required", []any{"topic"}), + )), + )), + )) + + bifrostReq := &schemas.BifrostChatRequest{ + Model: "anthropic.claude-3-7-sonnet-v1", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Classify this"), + }, + }, + }, + Params: &schemas.ChatParameters{ + ResponseFormat: &responseFormat, + Reasoning: &schemas.ChatReasoning{ + MaxTokens: schemas.Ptr(2048), + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := bedrock.ToBedrockChatCompletionRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.AdditionalModelRequestFields) + + outputConfigRaw, hasOutputConfig := 
result.AdditionalModelRequestFields.Get("output_config") + require.True(t, hasOutputConfig, "expected output_config for anthropic structured output") + + outputConfig, ok := outputConfigRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config to be an ordered map") + + formatRaw, hasFormat := outputConfig.Get("format") + require.True(t, hasFormat, "expected output_config.format") + + format, ok := formatRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config.format to be an ordered map") + + formatType, ok := format.Get("type") + require.True(t, ok, "expected output_config.format.type") + assert.Equal(t, "json_schema", formatType) + + schemaRaw, ok := format.Get("schema") + require.True(t, ok, "expected output_config.format.schema") + _, ok = schemaRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config.format.schema to remain ordered") +} + +// TestNonAnthropicStructuredOutputStillUsesToolConversion ensures Bedrock models +// other than Anthropic continue to use the legacy response_format->tool path. 
+func TestNonAnthropicStructuredOutputStillUsesToolConversion(t *testing.T) { + responseFormat := any(schemas.NewOrderedMapFromPairs( + schemas.KV("type", "json_schema"), + schemas.KV("json_schema", schemas.NewOrderedMapFromPairs( + schemas.KV("name", "classification"), + schemas.KV("schema", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "object"), + schemas.KV("description", "Return structured classification"), + schemas.KV("properties", schemas.NewOrderedMapFromPairs( + schemas.KV("topic", schemas.NewOrderedMapFromPairs( + schemas.KV("type", "string"), + )), + )), + schemas.KV("required", []any{"topic"}), + )), + )), + )) + + bifrostReq := &schemas.BifrostChatRequest{ + Model: "amazon.nova-pro-v1", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Classify this"), + }, + }, + }, + Params: &schemas.ChatParameters{ + ResponseFormat: &responseFormat, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := bedrock.ToBedrockChatCompletionRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + + // Non-Anthropic models should not use output_config.format. 
+ if result.AdditionalModelRequestFields != nil { + _, hasOutputConfig := result.AdditionalModelRequestFields.Get("output_config") + assert.False(t, hasOutputConfig, "expected no output_config for non-anthropic structured output") + } + + require.NotNil(t, result.ToolConfig, "expected tool_config for non-anthropic structured output") + require.NotEmpty(t, result.ToolConfig.Tools, "expected synthetic structured output tool to be added") + require.NotNil(t, result.ToolConfig.ToolChoice, "expected structured output tool choice to be forced") + require.NotNil(t, result.ToolConfig.ToolChoice.Tool, "expected structured output tool choice to target the synthetic tool") + assert.Equal(t, "bf_so_classification", result.ToolConfig.ToolChoice.Tool.Name) + assert.Equal(t, "bf_so_classification", result.ToolConfig.Tools[0].ToolSpec.Name) + + schemaRaw := result.ToolConfig.Tools[0].ToolSpec.InputSchema.JSON + var schema schemas.OrderedMap + require.NoError(t, schema.UnmarshalJSON(schemaRaw)) + schemaType, ok := schema.Get("type") + require.True(t, ok, "expected tool schema type") + assert.Equal(t, "object", schemaType) +} + +// TestAnthropicStructuredOutputMergesAdditionalModelRequestFieldPaths ensures +// additionalModelRequestFieldPaths are merged into existing AdditionalModelRequestFields +// and output_config is deep-merged instead of overwritten. 
+func TestAnthropicStructuredOutputMergesAdditionalModelRequestFieldPaths(t *testing.T) { + responseFormat := any(map[string]any{ + "type": "json_schema", + "json_schema": map[string]any{ + "name": "classification", + "schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "topic": map[string]any{ + "type": "string", + }, + }, + "required": []any{"topic"}, + }, + }, + }) + + bifrostReq := &schemas.BifrostChatRequest{ + Model: "anthropic.claude-3-7-sonnet-v1", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Classify this"), + }, + }, + }, + Params: &schemas.ChatParameters{ + ResponseFormat: &responseFormat, + Reasoning: &schemas.ChatReasoning{ + MaxTokens: schemas.Ptr(2048), + }, + ExtraParams: map[string]any{ + "additionalModelRequestFieldPaths": schemas.NewOrderedMapFromPairs( + schemas.KV("output_config", map[string]any{ + "foo": "bar", + }), + schemas.KV("customField", "customValue"), + ), + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + result, err := bedrock.ToBedrockChatCompletionRequest(ctx, bifrostReq) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.AdditionalModelRequestFields) + + outputConfigRaw, hasOutputConfig := result.AdditionalModelRequestFields.Get("output_config") + require.True(t, hasOutputConfig, "expected output_config to exist after merge") + outputConfig, ok := outputConfigRaw.(*schemas.OrderedMap) + require.True(t, ok, "expected output_config to be an ordered map") + + // Existing structured output format must be preserved. 
+	formatRaw, hasFormat := outputConfig.Get("format")
+	require.True(t, hasFormat, "expected output_config.format to be preserved")
+	format, ok := formatRaw.(*schemas.OrderedMap)
+	require.True(t, ok, "expected output_config.format to be an ordered map")
+	formatType, hasType := format.Get("type")
+	require.True(t, hasType, "expected output_config.format.type")
+	assert.Equal(t, "json_schema", formatType)
+	_, hasSchema := format.Get("schema")
+	assert.True(t, hasSchema, "expected output_config.format.schema")
+
+	// Incoming additionalModelRequestFieldPaths.output_config key must be merged.
+	foo, hasFoo := outputConfig.Get("foo")
+	require.True(t, hasFoo, "expected output_config.foo to be merged")
+	assert.Equal(t, "bar", foo)
+
+	// Existing top-level field (thinking) must not be lost.
+	_, hasThinking := result.AdditionalModelRequestFields.Get("thinking")
+	assert.True(t, hasThinking, "expected thinking to be preserved")
+
+	// Incoming top-level keys must be merged.
+	customField, hasCustomField := result.AdditionalModelRequestFields.Get("customField")
+	require.True(t, hasCustomField, "expected customField to be merged")
+	assert.Equal(t, "customValue", customField)
+}
+
 // TestNovaReasoningConfigUsesReasoningConfigField verifies that Nova models use
 // the "reasoningConfig" field (camelCase) and NOT "thinking".
func TestNovaReasoningConfigUsesReasoningConfigField(t *testing.T) { @@ -3632,22 +4175,22 @@ func TestToBedrockInvokeMessagesStreamResponse_NoDuplicateContentBlockStop(t *te { Type: schemas.ResponsesStreamResponseTypeOutputTextDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeContentPartDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, { Type: schemas.ResponsesStreamResponseTypeOutputItemDone, ContentIndex: &contentIdx, - ExtraFields: schemas.BifrostResponseExtraFields{ModelRequested: model}, + ExtraFields: schemas.BifrostResponseExtraFields{OriginalModelRequested: model}, }, } type bedrockChunk struct { - InvokeModelRawChunk []byte `json:"invokeModelRawChunk"` + InvokeModelRawChunks [][]byte `json:"invokeModelRawChunks"` } var stopCount int @@ -3661,9 +4204,10 @@ func TestToBedrockInvokeMessagesStreamResponse_NoDuplicateContentBlockStop(t *te require.NoError(t, err) var chunk bedrockChunk require.NoError(t, json.Unmarshal(raw, &chunk)) - if len(chunk.InvokeModelRawChunk) > 0 && - strings.Contains(string(chunk.InvokeModelRawChunk), "content_block_stop") { - stopCount++ + for _, rawChunk := range chunk.InvokeModelRawChunks { + if strings.Contains(string(rawChunk), "content_block_stop") { + stopCount++ + } } } diff --git a/core/providers/bedrock/chat.go b/core/providers/bedrock/chat.go index 71e7890935..6459df377b 100644 --- a/core/providers/bedrock/chat.go +++ b/core/providers/bedrock/chat.go @@ -247,8 +247,6 @@ func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Conte Usage: usage, Created: int(time.Now().Unix()), ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: 
schemas.Bedrock, }, } diff --git a/core/providers/bedrock/convert_tool_config_test.go b/core/providers/bedrock/convert_tool_config_test.go new file mode 100644 index 0000000000..fc417394e5 --- /dev/null +++ b/core/providers/bedrock/convert_tool_config_test.go @@ -0,0 +1,477 @@ +package bedrock + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestConvertToolConfig_DropsServerToolsOnBedrock locks in the bug fix from +// the user-reported repro: sending `web_search_20260209` via the OpenAI- +// compatible /v1/chat/completions endpoint to Bedrock was producing a +// malformed ToolConfig that Bedrock rejected with 400 "The provided request +// is not valid". The fix strips unsupported server tools before the +// conversion loop so the outbound request is valid. +func TestConvertToolConfig_DropsServerToolsOnBedrock(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: schemas.Ptr("Get weather by city"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + }, + }, + }, + { + // Server tool — Bedrock doesn't support web_search per Table 20. + // Should be stripped silently. + Type: schemas.ChatToolType("web_search_20260209"), + Name: "web_search", + }, + }, + } + + cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params) + if cfg == nil { + t.Fatalf("expected ToolConfig, got nil (function tool should have survived)") + } + if len(cfg.Tools) != 1 { + t.Fatalf("expected exactly 1 tool (function), got %d: %+v", len(cfg.Tools), cfg.Tools) + } + if cfg.Tools[0].ToolSpec == nil || cfg.Tools[0].ToolSpec.Name != "get_weather" { + t.Errorf("expected function tool 'get_weather' to survive, got %+v", cfg.Tools[0]) + } +} + +// TestConvertToolConfig_ReturnsNilWhenAllDropped locks in the empty-slice +// guard. 
Bedrock's Converse API rejects `"toolConfig": {"tools": []}` with a +// 400; when every tool is unsupported and gets stripped, convertToolConfig +// must return nil so no ToolConfig ships at all. +func TestConvertToolConfig_ReturnsNilWhenAllDropped(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("web_search_20260209"), + Name: "web_search", + }, + { + Type: schemas.ChatToolType("web_fetch_20260309"), + Name: "web_fetch", + }, + { + Type: schemas.ChatToolType("code_execution_20250825"), + Name: "code_execution", + }, + }, + } + + cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params) + if cfg != nil { + t.Fatalf("expected nil ToolConfig (all tools unsupported on Bedrock), got %+v", cfg) + } +} + +// TestConvertToolConfig_KeepsBedrockSupportedServerTools — locks in that +// Bedrock-supported server tools (bash, memory, text_editor, computer, +// tool_search) do NOT appear in Converse's typed toolConfig.tools slot — +// they must be tunneled via additionalModelRequestFields (exercised in +// TestCollectBedrockServerTools_*). If the only tool is a server tool, +// toolConfig is nil so we don't ship {"toolConfig": {"tools": []}}. +func TestConvertToolConfig_KeepsBedrockSupportedServerTools(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("bash_20250124"), + Name: "bash", + }, + }, + } + + cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params) + if cfg != nil { + t.Fatalf("expected nil toolConfig (server tools flow via additionalModelRequestFields, not toolSpec), got %+v", cfg) + } +} + +// TestCollectBedrockServerTools_BashOnly — bash is Bedrock-supported per the +// B-header list; the helper must emit it as a native-JSON tool entry with no +// derived beta header (bash has no high-confidence 1:1 beta-header mapping; +// callers rely on extra-headers for that). 
+func TestCollectBedrockServerTools_BashOnly(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("bash_20250124"), + Name: "bash", + }, + }, + } + tools, betas := collectBedrockServerTools(params) + if len(tools) != 1 { + t.Fatalf("expected 1 server tool, got %d", len(tools)) + } + got := string(tools[0]) + if !strings.Contains(got, `"type":"bash_20250124"`) || !strings.Contains(got, `"name":"bash"`) { + t.Errorf("expected native Anthropic bash shape, got %s", got) + } + if len(betas) != 0 { + t.Errorf("expected no derived beta headers for bash (no 1:1 mapping), got %v", betas) + } +} + +// TestCollectBedrockServerTools_ComputerDerivesBeta — computer_YYYYMMDD must +// derive computer-use-YYYY-MM-DD as the beta header, gated through +// FilterBetaHeadersForProvider(Bedrock) which keeps computer-use-* headers. +func TestCollectBedrockServerTools_ComputerDerivesBeta(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("computer_20251124"), + Name: "computer", + DisplayWidthPx: schemas.Ptr(1280), + DisplayHeightPx: schemas.Ptr(800), + }, + }, + } + tools, betas := collectBedrockServerTools(params) + if len(tools) != 1 { + t.Fatalf("expected 1 server tool, got %d", len(tools)) + } + if !strings.Contains(string(tools[0]), `"display_width_px":1280`) { + t.Errorf("expected computer variant fields to flow through, got %s", string(tools[0])) + } + if len(betas) != 1 || betas[0] != "computer-use-2025-11-24" { + t.Errorf("expected [computer-use-2025-11-24], got %v", betas) + } +} + +// TestCollectBedrockServerTools_MemoryDerivesContextManagement — memory +// activates via the context-management-2025-06-27 bundle on Bedrock (cite: +// anthropic/types.go:179). 
+func TestCollectBedrockServerTools_MemoryDerivesContextManagement(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("memory_20250818"), + Name: "memory", + }, + }, + } + _, betas := collectBedrockServerTools(params) + if len(betas) != 1 || betas[0] != "context-management-2025-06-27" { + t.Errorf("expected [context-management-2025-06-27], got %v", betas) + } +} + +// TestCollectBedrockServerTools_StripsUnsupported — web_search isn't in +// Bedrock's ProviderFeatures (WebSearch=false), so ValidateChatToolsForProvider +// drops it and the helper must emit nothing. +func TestCollectBedrockServerTools_StripsUnsupported(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("web_search_20260209"), + Name: "web_search", + }, + }, + } + tools, betas := collectBedrockServerTools(params) + if len(tools) != 0 { + t.Errorf("expected no server tools (web_search unsupported on Bedrock), got %d", len(tools)) + } + if len(betas) != 0 { + t.Errorf("expected no betas when all tools filtered, got %v", betas) + } +} + +// TestCollectBedrockServerTools_FunctionToolsIgnored — function/custom tools +// go through convertToolConfig, not this helper. +func TestCollectBedrockServerTools_FunctionToolsIgnored(t *testing.T) { + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + }, + }, + }, + }, + } + tools, betas := collectBedrockServerTools(params) + if len(tools) != 0 || len(betas) != 0 { + t.Errorf("function tools should not flow through server-tool helper, got tools=%d betas=%v", len(tools), betas) + } +} + +// TestBuildBedrockServerToolChoice_PinnedServerTool — caller pins a kept +// server tool (computer) by name. 
Converse's typed toolConfig.toolChoice path +// can't carry this because toolConfig.tools doesn't include server tools; the +// existing reconciliation silently drops the pin. The tunneled path must +// emit {"type":"tool","name":"computer"} into additionalModelRequestFields. +func TestBuildBedrockServerToolChoice_PinnedServerTool(t *testing.T) { + computer := schemas.ChatTool{ + Type: schemas.ChatToolType("computer_20251124"), + Name: "computer", + DisplayWidthPx: schemas.Ptr(1280), + } + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{computer}, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: schemas.ChatToolChoiceTypeFunction, + Function: &schemas.ChatToolChoiceFunction{Name: "computer"}, + }, + }, + } + choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{computer}) + if !ok { + t.Fatalf("expected tunneled tool_choice for pinned server tool, got (nil, false)") + } + got := string(choice) + if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) { + t.Errorf("expected Anthropic-native {type:tool,name:computer}, got %s", got) + } +} + +// TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled — function +// tool pins stay on Converse's typed path (toolConfig.toolChoice.tool). The +// helper must not double-emit. 
+func TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled(t *testing.T) { + fn := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Parameters: &schemas.ToolFunctionParameters{Type: "object"}, + }, + } + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{fn}, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: schemas.ChatToolChoiceTypeFunction, + Function: &schemas.ChatToolChoiceFunction{Name: "get_weather"}, + }, + }, + } + if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn}); ok { + t.Errorf("expected no tunneling for function-tool pin (typed Converse path handles it)") + } +} + +// TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools — tool_choice:any +// with only server tools: convertToolConfig returns nil (bedrockTools empty), +// so the typed any-contract is lost. The tunneled path must emit +// {"type":"any"} to preserve the forcing semantics. +func TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools(t *testing.T) { + bash := schemas.ChatTool{ + Type: schemas.ChatToolType("bash_20250124"), + Name: "bash", + } + anyStr := string(schemas.ChatToolChoiceTypeAny) + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{bash}, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStr: &anyStr, + }, + } + choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{bash}) + if !ok { + t.Fatalf("expected tunneled any-contract when only server tools are present, got (nil, false)") + } + got := string(choice) + if !strings.Contains(got, `"type":"any"`) { + t.Errorf("expected {type:any}, got %s", got) + } +} + +// TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled — when at +// least one function/custom tool is present, Converse's typed +// toolConfig.toolChoice.any carries the any-contract. Don't double-emit. 
+func TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled(t *testing.T) { + fn := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Parameters: &schemas.ToolFunctionParameters{Type: "object"}, + }, + } + bash := schemas.ChatTool{ + Type: schemas.ChatToolType("bash_20250124"), + Name: "bash", + } + anyStr := string(schemas.ChatToolChoiceTypeAny) + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{fn, bash}, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStr: &anyStr, + }, + } + if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn, bash}); ok { + t.Errorf("expected no tunneling when function/custom tool is present (typed Converse path handles any)") + } +} + +// TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled — the +// caller pins web_search, which ValidateChatToolsForProvider strips on +// Bedrock. The pin name is absent from the filtered set; the helper must not +// fabricate a tunneled tool_choice for a tool that isn't in the request. +func TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled(t *testing.T) { + // The caller's original request had web_search, but it's been stripped. + // We pass the filtered slice (empty for the server-tool axis) to mimic + // the convertChatParameters call path. + params := &schemas.ChatParameters{ + Tools: []schemas.ChatTool{{Type: schemas.ChatToolType("web_search_20260209"), Name: "web_search"}}, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: schemas.ChatToolChoiceTypeFunction, + Function: &schemas.ChatToolChoiceFunction{Name: "web_search"}, + }, + }, + } + // Filtered (post-ValidateChatToolsForProvider(Bedrock)) — web_search is dropped. 
+ filtered := []schemas.ChatTool{} + if _, ok := buildBedrockServerToolChoice(params, filtered); ok { + t.Errorf("expected no tunneling when pinned name was stripped by provider validation") + } +} + +// TestConvertChatParameters_PinnedServerToolE2E — end-to-end verification +// that convertChatParameters composes convertToolConfig + +// collectBedrockServerTools + buildBedrockServerToolChoice such that a +// request pinning a kept server tool produces: +// - AdditionalModelRequestFields.tools containing the server tool +// - AdditionalModelRequestFields.tool_choice with Anthropic-native shape +// - ToolConfig nil (no function tools → Converse's typed path is inactive) +func TestConvertChatParameters_PinnedServerToolE2E(t *testing.T) { + bifrostReq := &schemas.BifrostChatRequest{ + Model: "global.anthropic.claude-sonnet-4-6", + Params: &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("computer_20251124"), + Name: "computer", + DisplayWidthPx: schemas.Ptr(1280), + }, + }, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: schemas.ChatToolChoiceTypeFunction, + Function: &schemas.ChatToolChoiceFunction{Name: "computer"}, + }, + }, + }, + } + bedrockReq := &BedrockConverseRequest{} + if err := convertChatParameters(nil, bifrostReq, bedrockReq); err != nil { + t.Fatalf("convertChatParameters failed: %v", err) + } + if bedrockReq.ToolConfig != nil { + t.Errorf("expected nil ToolConfig (no function/custom tools), got %+v", bedrockReq.ToolConfig) + } + if bedrockReq.AdditionalModelRequestFields == nil { + t.Fatalf("expected AdditionalModelRequestFields to carry server-tool payload, got nil") + } + tools, ok := bedrockReq.AdditionalModelRequestFields.Get("tools") + if !ok { + t.Errorf("expected additionalModelRequestFields.tools to be set for server tool") + } else if toolsSlice, castOK := tools.([]json.RawMessage); !castOK || len(toolsSlice) != 1 { + t.Errorf("expected 1 server tool in 
additionalModelRequestFields.tools, got %+v", tools) + } + choice, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice") + if !ok { + t.Fatalf("expected additionalModelRequestFields.tool_choice to carry pinned server-tool contract") + } + choiceRaw, castOK := choice.(json.RawMessage) + if !castOK { + t.Fatalf("expected tool_choice value to be json.RawMessage, got %T", choice) + } + got := string(choiceRaw) + if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) { + t.Errorf("expected {type:tool,name:computer}, got %s", got) + } +} + +// TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice +// locks in the fix for the "two conflicting tool-choice directives" hazard: +// when response_format forces the synthetic bf_so_* tool via +// ToolConfig.ToolChoice, the tunneled additionalModelRequestFields.tool_choice +// (which would pin a server tool) must be suppressed so Bedrock doesn't +// receive both pins in the same Converse call. Uses a Nova model since +// Anthropic models route response_format through native output_config.format +// (no synthetic tool), so the conflict only surfaces on non-Anthropic +// Bedrock targets. 
+func TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice(t *testing.T) { + responseFormat := any(map[string]any{ + "type": "json_schema", + "json_schema": map[string]any{ + "name": "classification", + "schema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "topic": map[string]any{"type": "string"}, + }, + "required": []any{"topic"}, + }, + }, + }) + + bifrostReq := &schemas.BifrostChatRequest{ + Model: "amazon.nova-pro-v1:0", + Params: &schemas.ChatParameters{ + ResponseFormat: &responseFormat, + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolType("bash_20250124"), + Name: "bash", + }, + }, + ToolChoice: &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: schemas.ChatToolChoiceTypeFunction, + Function: &schemas.ChatToolChoiceFunction{Name: "bash"}, + }, + }, + }, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bedrockReq := &BedrockConverseRequest{} + if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil { + t.Fatalf("convertChatParameters failed: %v", err) + } + + // Synthetic bf_so_* tool must be injected and pinned via Converse's typed path. + if bedrockReq.ToolConfig == nil { + t.Fatalf("expected ToolConfig with synthetic bf_so_* tool, got nil") + } + if bedrockReq.ToolConfig.ToolChoice == nil || bedrockReq.ToolConfig.ToolChoice.Tool == nil { + t.Fatalf("expected ToolConfig.ToolChoice.Tool to pin synthetic structured-output tool, got %+v", bedrockReq.ToolConfig.ToolChoice) + } + if !strings.HasPrefix(bedrockReq.ToolConfig.ToolChoice.Tool.Name, "bf_so_") { + t.Errorf("expected ToolConfig.ToolChoice.Tool.Name to start with bf_so_, got %q", bedrockReq.ToolConfig.ToolChoice.Tool.Name) + } + + // Server tool must still be tunneled so the model has it available. 
+ if bedrockReq.AdditionalModelRequestFields == nil { + t.Fatalf("expected AdditionalModelRequestFields to carry tunneled server-tool payload, got nil") + } + if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tools"); !ok { + t.Errorf("expected additionalModelRequestFields.tools to still carry bash server tool") + } + + // Guarded field: tunneled tool_choice MUST be absent because response_format + // forces the synthetic tool. Two tool-choice directives in the same request + // would let Bedrock pick one and silently violate the structured-output contract. + if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice"); ok { + t.Errorf("expected NO additionalModelRequestFields.tool_choice when response_format pins bf_so_* (conflict hazard)") + } +} diff --git a/core/providers/bedrock/embedding.go b/core/providers/bedrock/embedding.go index d9981a1e4d..2e2875e9cf 100644 --- a/core/providers/bedrock/embedding.go +++ b/core/providers/bedrock/embedding.go @@ -1,10 +1,10 @@ package bedrock import ( + "encoding/json" "fmt" "strings" - "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/schemas" ) @@ -19,11 +19,6 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) return nil, fmt.Errorf("no input text provided for embedding") } - // Validate dimensions parameter - Titan models do not support it - if bifrostReq.Params != nil && bifrostReq.Params.Dimensions != nil { - return nil, fmt.Errorf("amazon Titan embedding models do not support custom dimensions parameter") - } - titanReq := &BedrockTitanEmbeddingRequest{} // Set input text @@ -36,8 +31,26 @@ func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) } titanReq.InputText = embeddingText } + if bifrostReq.Params != nil { - titanReq.ExtraParams = bifrostReq.Params.ExtraParams + titanReq.Dimensions = bifrostReq.Params.Dimensions + if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok { + if b, ok := 
normalize.(bool); ok { + titanReq.Normalize = &b + } + } + // Forward remaining extra params (excluding normalize which is now a first-class field) + if len(bifrostReq.Params.ExtraParams) > 0 { + extra := make(map[string]interface{}) + for k, v := range bifrostReq.Params.ExtraParams { + if k != "normalize" { + extra[k] = v + } + } + if len(extra) > 0 { + titanReq.ExtraParams = extra + } + } } return titanReq, nil @@ -69,20 +82,81 @@ func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *sch return bifrostResponse } -// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format -// Reuses the Cohere converter since the format is identical -func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*cohere.CohereEmbeddingRequest, error) { +// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body. 
+func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost embedding request is nil") } + if bifrostReq.Input == nil { + return nil, fmt.Errorf("no input provided for embedding") + } - // Reuse Cohere's converter - the format is identical for Bedrock - cohereReq := cohere.ToCohereEmbeddingRequest(bifrostReq) - if cohereReq == nil { - return nil, fmt.Errorf("failed to convert to Cohere embedding request") + req := &BedrockCohereEmbeddingRequest{} + + // Map texts + if bifrostReq.Input.Text != nil { + req.Texts = []string{*bifrostReq.Input.Text} + } else if len(bifrostReq.Input.Texts) > 0 { + req.Texts = bifrostReq.Input.Texts } - return cohereReq, nil + if bifrostReq.Params != nil { + extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams)) + for k, v := range bifrostReq.Params.ExtraParams { + extra[k] = v + } + + if v, ok := extra["input_type"]; ok { + if s, ok := v.(string); ok { + req.InputType = s + delete(extra, "input_type") + } + } + if v, ok := extra["truncate"]; ok { + if s, ok := v.(string); ok { + req.Truncate = &s + delete(extra, "truncate") + } + } + if v, ok := extra["embedding_types"]; ok { + if ss, ok := v.([]string); ok { + req.EmbeddingTypes = ss + delete(extra, "embedding_types") + } + } + if v, ok := extra["images"]; ok { + if ss, ok := v.([]string); ok { + req.Images = ss + delete(extra, "images") + } + } + if v, ok := extra["inputs"]; ok { + if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok { + req.Inputs = inputs + delete(extra, "inputs") + } + } + if v, ok := extra["max_tokens"]; ok { + switch n := v.(type) { + case int: + req.MaxTokens = &n + delete(extra, "max_tokens") + case float64: + i := int(n) + req.MaxTokens = &i + delete(extra, "max_tokens") + } + } + if bifrostReq.Params.Dimensions != nil { + req.OutputDimension = bifrostReq.Params.Dimensions + } + if len(extra) > 0 { + req.ExtraParams = 
extra + } + } + + return req, nil } // DetermineEmbeddingModelType determines the embedding model type from the model name @@ -96,3 +170,102 @@ func DetermineEmbeddingModelType(model string) (string, error) { return "", fmt.Errorf("unsupported embedding model: %s", model) } } + +// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format. +// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats" +// (the default, when no embedding_types are requested), and as a typed object when +// response_type is "embeddings_by_type". +func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) { + if r == nil { + return nil, fmt.Errorf("nil Bedrock Cohere embedding response") + } + + bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"} + + switch r.ResponseType { + case "embeddings_by_type": + // Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]} + var typed struct { + Float [][]float32 `json:"float"` + Base64 []string `json:"base64"` + Int8 [][]int8 `json:"int8"` + Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue + Binary [][]int8 `json:"binary"` + Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue + } + if err := json.Unmarshal(r.Embeddings, &typed); err != nil { + return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err) + } + if typed.Float != nil { + for i, emb := range typed.Float { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + if typed.Base64 != nil { + for i, emb := range typed.Base64 { + e := emb + bifrostResponse.Data = append(bifrostResponse.Data, 
schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e}, + }) + } + } + for i, emb := range typed.Int8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Binary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb}, + }) + } + for i, emb := range typed.Uint8 { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + for i, emb := range typed.Ubinary { + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb}, + }) + } + + default: + // Default / "embeddings_floats": raw array form [[...], [...]] + var floats [][]float32 + if err := json.Unmarshal(r.Embeddings, &floats); err != nil { + return nil, fmt.Errorf("error parsing embeddings_floats: %w", err) + } + for i, emb := range floats { + float64Emb := make([]float64, len(emb)) + for j, v := range emb { + float64Emb[j] = float64(v) + } + bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{ + Object: "embedding", + Index: i, + Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb}, + }) + } + } + + return bifrostResponse, nil +} diff --git a/core/providers/bedrock/images.go b/core/providers/bedrock/images.go index 8c9ba9569c..dc3c76edd4 100644 --- a/core/providers/bedrock/images.go +++ b/core/providers/bedrock/images.go @@ -34,6 +34,61 @@ func mapQualityToBedrock(quality *string) *string { } } +// isStabilityAIModel returns true if the model is a Stability AI model (contains "stability.") +func 
isStabilityAIModel(model string) bool {
+	return strings.Contains(strings.ToLower(model), "stability.")
+}
+
+// isPromptOnlyImageGenerationModel returns true for image generation models that use a flat
+// {"prompt": "..."} payload (no taskType field). Covers Vertex Imagen and similar models.
+// Stability AI is excluded here — it's handled separately because it also supports image edit.
+func isPromptOnlyImageGenerationModel(model string) bool {
+	m := strings.ToLower(model)
+	return strings.Contains(m, "image") && !isStabilityAIModel(m)
+}
+
+// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
+// flat request format used by Bedrock (stability.stable-image-* models).
+func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
+	if request == nil {
+		return nil, fmt.Errorf("request is nil")
+	}
+	if request.Input == nil {
+		return nil, fmt.Errorf("request input is required")
+	}
+
+	req := &StabilityAIImageGenerationRequest{
+		Prompt: request.Input.Prompt,
+	}
+
+	if request.Params != nil {
+		if request.Params.AspectRatio != nil {
+			req.AspectRatio = request.Params.AspectRatio
+		}
+		if request.Params.OutputFormat != nil {
+			req.OutputFormat = request.Params.OutputFormat
+		}
+		if request.Params.Seed != nil {
+			req.Seed = request.Params.Seed
+		}
+		if request.Params.NegativePrompt != nil {
+			req.NegativePrompt = request.Params.NegativePrompt
+		}
+		if request.Params.ExtraParams != nil {
+			// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
+			if req.AspectRatio == nil {
+				if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
+					delete(request.Params.ExtraParams, "aspect_ratio")
+					req.AspectRatio = ar
+				}
+			}
+			req.ExtraParams = request.Params.ExtraParams
+		}
+	}
+
+	return req, nil
+}
+
+// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image
generation request func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) { if request == nil { @@ -41,7 +96,7 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } if request.Input == nil { - return nil, fmt.Errorf("request.Input is required") + return nil, fmt.Errorf("request input is required") } bedrockReq := &BedrockImageGenerationRequest{ @@ -98,7 +153,24 @@ func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequ } return bedrockReq, nil +} +// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to +// the native Bedrock invoke API response format used by Stability AI models. +// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas. +func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) { + if response == nil { + return nil, fmt.Errorf("response is nil") + } + result := &BedrockImageGenerationResponse{} + for _, d := range response.Data { + result.Images = append(result.Images, d.B64JSON) + } + if response.ImageGenerationResponseParameters != nil { + result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons + result.Seeds = response.ImageGenerationResponseParameters.Seeds + } + return result, nil } // ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request @@ -358,6 +430,292 @@ func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGener return config } +// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value +// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type. 
+func getStabilityAITaskTypeFromParams(t string) string { + switch strings.ToLower(t) { + case "inpainting", "inpaint": + return "inpaint" + case "outpainting", "outpaint": + return "outpaint" + case "background_removal", "remove_background": + return "remove-bg" + case "erase_object": + return "erase-object" + case "upscale_fast": + return "upscale-fast" + case "upscale_creative": + return "upscale-creative" + case "upscale_conservative": + return "upscale-conservative" + case "recolor": + return "recolor" + case "search_replace": + return "search-replace" + case "control_sketch": + return "control-sketch" + case "control_structure": + return "control-structure" + case "style_guide": + return "style-guide" + case "style_transfer": + return "style-transfer" + default: + return "" + } +} + +// getStabilityAIEditTaskType infers the Stability AI edit task from the model name. +// Returns an error if the model name does not match any known pattern. +func getStabilityAIEditTaskType(model string) (string, error) { + m := strings.ToLower(model) + switch { + case strings.Contains(m, "stable-creative-upscale"): + return "upscale-creative", nil + case strings.Contains(m, "stable-conservative-upscale"): + return "upscale-conservative", nil + case strings.Contains(m, "stable-fast-upscale"): + return "upscale-fast", nil + case strings.Contains(m, "stable-image-inpaint"): + return "inpaint", nil + case strings.Contains(m, "stable-outpaint"): + return "outpaint", nil + case strings.Contains(m, "stable-image-search-recolor"): + return "recolor", nil + case strings.Contains(m, "stable-image-search-replace"): + return "search-replace", nil + case strings.Contains(m, "stable-image-erase-object"): + return "erase-object", nil + case strings.Contains(m, "stable-image-remove-background"): + return "remove-bg", nil + case strings.Contains(m, "stable-image-control-sketch"): + return "control-sketch", nil + case strings.Contains(m, "stable-image-control-structure"): + return 
"control-structure", nil + case strings.Contains(m, "stable-image-style-guide"): + return "style-guide", nil + case strings.Contains(m, "stable-style-transfer"): + return "style-transfer", nil + default: + return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model) + } +} + +// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request +// format used by Bedrock edit models. Only fields valid for the detected task type are populated. +// deployment is the resolved model identifier (after applying any deployment alias mapping); it is +// used for task-type inference so that alias-mapped models route correctly. +func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) { + if request == nil || request.Input == nil { + return nil, fmt.Errorf("request or input is nil") + } + + var taskType string + if request.Params != nil && request.Params.Type != nil { + taskType = getStabilityAITaskTypeFromParams(*request.Params.Type) + } + if taskType == "" { + var err error + taskType, err = getStabilityAIEditTaskType(deployment) + if err != nil { + return nil, err + } + } + + req := &StabilityAIImageEditRequest{} + + // Image sourcing + if taskType == "style-transfer" { + if len(request.Input.Images) != 2 { + return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image") + } + if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 { + return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image") + } + initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image) + req.InitImage = &initB64 + req.StyleImage = &styleB64 + } else { + if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 { + return nil, fmt.Errorf("at least one image is 
required") + } + imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image) + req.Image = &imageB64 + } + + // Common fields populated based on task allowlist + prompt := request.Input.Prompt + switch taskType { + case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure", + "style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer": + req.Prompt = &prompt + } + + // Negative prompt + if request.Params != nil && request.Params.NegativePrompt != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.NegativePrompt = request.Params.NegativePrompt + } + } + + // Seed + if request.Params != nil && request.Params.Seed != nil { + switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch", + "control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer": + req.Seed = request.Params.Seed + } + } + + // Mask (from Params.Mask bytes) + if request.Params != nil && len(request.Params.Mask) > 0 { + switch taskType { + case "inpaint", "erase-object": + maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask) + req.Mask = &maskB64 + } + } + + // ExtraParams + if request.Params != nil { + // Typed OutputFormat takes priority over ExtraParams + if request.Params.OutputFormat != nil { + req.OutputFormat = request.Params.OutputFormat + } + + if request.Params.ExtraParams != nil { + ep := make(map[string]interface{}, len(request.Params.ExtraParams)) + for k, v := range request.Params.ExtraParams { + ep[k] = v + } + + // output_format — all tasks (fallback if not already set by typed field) + if req.OutputFormat == nil { + if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok { + delete(ep, "output_format") + req.OutputFormat = v + } + } + + // style_preset 
+ switch taskType { + case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch", + "control-structure", "style-guide", "upscale-creative": + if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok { + delete(ep, "style_preset") + req.StylePreset = v + } + } + + // grow_mask + switch taskType { + case "inpaint", "recolor", "search-replace", "erase-object": + if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok { + delete(ep, "grow_mask") + req.GrowMask = v + } + } + + // outpaint directional fields + if taskType == "outpaint" { + if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok { + delete(ep, "left") + req.Left = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok { + delete(ep, "right") + req.Right = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok { + delete(ep, "up") + req.Up = v + } + if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok { + delete(ep, "down") + req.Down = v + } + } + + // creativity + switch taskType { + case "upscale-creative", "upscale-conservative", "outpaint": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok { + delete(ep, "creativity") + req.Creativity = v + } + } + + // select_prompt (recolor) + if taskType == "recolor" { + if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok { + delete(ep, "select_prompt") + req.SelectPrompt = v + } + } + + // search_prompt (search-replace) + if taskType == "search-replace" { + if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok { + delete(ep, "search_prompt") + req.SearchPrompt = v + } + } + + // control_strength + switch taskType { + case "control-sketch", "control-structure": + if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok { + delete(ep, "control_strength") + req.ControlStrength = v + } + } + + // style-guide fields + if taskType == "style-guide" { + if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok { + delete(ep, 
"aspect_ratio") + req.AspectRatio = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok { + delete(ep, "fidelity") + req.Fidelity = v + } + } + + // style-transfer fields + if taskType == "style-transfer" { + if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok { + delete(ep, "style_strength") + req.StyleStrength = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok { + delete(ep, "composition_fidelity") + req.CompositionFidelity = v + } + if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok { + delete(ep, "change_strength") + req.ChangeStrength = v + } + } + + req.ExtraParams = ep + } + } + + // Validate required per-task fields + if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") { + return nil, fmt.Errorf("select_prompt is required for stability ai recolor task") + } + if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") { + return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task") + } + + return req, nil +} + // ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse { if response == nil { @@ -366,6 +724,13 @@ func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) bifrostResponse := &schemas.BifrostImageGenerationResponse{} + if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 { + bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{ + FinishReasons: append([]*string(nil), response.FinishReasons...), + Seeds: append([]int(nil), response.Seeds...), + } + } + for index, image := range response.Images { bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{ B64JSON: image, diff --git 
a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index bbd090a214..8227e8639a 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -2,8 +2,10 @@ package bedrock import ( "bytes" + "encoding/base64" "encoding/json" "fmt" + "net/url" "strings" "github.com/bytedance/sonic" @@ -44,6 +46,17 @@ var bedrockInvokeRequestKnownFields = map[string]bool{ "message": true, "chat_history": true, // AI21 "n": true, "frequency_penalty": true, "presence_penalty": true, + // Bedrock image gen / edit / variation (Titan/Nova Canvas) + "taskType": true, "textToImageParams": true, "imageVariationParams": true, + "inPaintingParams": true, "outPaintingParams": true, "backgroundRemovalParams": true, + "imageGenerationConfig": true, + // Stability AI image + "image": true, "mask": true, "negative_prompt": true, + "aspect_ratio": true, "output_format": true, "seed": true, + // Embeddings + "inputText": true, "texts": true, "input_type": true, + "normalize": true, "dimensions": true, + "embedding_types": true, "output_dimension": true, "inputs": true, // Internal "stream": true, "extra_params": true, } @@ -125,17 +138,74 @@ func (r *BedrockInvokeRequest) UnmarshalJSON(data []byte) error { return nil } -// DetectInvokeRequestType determines the request type from raw JSON body -// without full deserialization, keeping detection logic colocated with IsMessagesRequest. -func DetectInvokeRequestType(body []byte) schemas.RequestType { - node, _ := sonic.Get(body, "messages") - if node.Exists() { - raw, err := node.Raw() - if err == nil && raw != "null" && raw != "[]" { +// DetectInvokeRequestType determines the request type from raw JSON body and model ID +// without full deserialization, keeping detection logic colocated with conversion methods. 
+func DetectInvokeRequestType(body []byte, modelID string) schemas.RequestType { + // Messages → chat/responses path + if node, _ := sonic.Get(body, "messages"); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { return schemas.ResponsesRequest } + } + + // Titan uses "inputText" for both embeddings and text generation. + // Use the model ID to disambiguate: embedding models contain "embed". + if node, _ := sonic.Get(body, "inputText"); node.Exists() { + if strings.Contains(strings.ToLower(modelID), "embed") { + return schemas.EmbeddingRequest + } return schemas.TextCompletionRequest } + + // Cohere embedding: text-only (texts), image-only (images), or mixed (inputs). + // Use model ID to identify embed models, then check for any non-empty payload field. + if strings.Contains(strings.ToLower(modelID), "embed") { + for _, field := range []string{"texts", "images", "inputs"} { + if node, _ := sonic.Get(body, field); node.Exists() { + if raw, err := node.Raw(); err == nil && raw != "null" && raw != "[]" { + return schemas.EmbeddingRequest + } + } + } + } + + // taskType-based image routing + if taskNode, _ := sonic.Get(body, "taskType"); taskNode.Exists() { + taskType, _ := taskNode.String() + switch taskType { + case TaskTypeTextImage: + return schemas.ImageGenerationRequest + case TaskTypeImageVariation: + return schemas.ImageVariationRequest + case TaskTypeInpainting, TaskTypeOutpainting, TaskTypeBackgroundRemoval: + return schemas.ImageEditRequest + } + } + + // URL-decode the model ID once for all model-name checks below + decodedModelID := modelID + if unescaped, err := url.PathUnescape(modelID); err == nil { + decodedModelID = unescaped + } + + // Stability AI: supports both generation (prompt-only) and edit (image+prompt) + if isStabilityAIModel(decodedModelID) { + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + return schemas.ImageGenerationRequest + } + + // explicit image 
field -> edit request + if node, _ := sonic.Get(body, "image"); node.Exists() { + return schemas.ImageEditRequest + } + + // Checked after all body-field and model-specific signals so it doesn't shadow known models. + if isPromptOnlyImageGenerationModel(decodedModelID) { + return schemas.ImageGenerationRequest + } + return schemas.TextCompletionRequest } @@ -310,6 +380,382 @@ func (r *BedrockInvokeRequest) ToBifrostTextCompletionRequest(ctx *schemas.Bifro return textReq.ToBifrostTextCompletionRequest(ctx) } +// ToBifrostEmbeddingRequest converts the invoke request to a BifrostEmbeddingRequest. +// Handles both Titan (inputText) and Cohere (texts) embedding formats. +func (r *BedrockInvokeRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostEmbeddingRequest{ + Provider: provider, + Model: model, + } + + if r.InputText != "" { + req.Input = &schemas.EmbeddingInput{Text: &r.InputText} + } else if len(r.Texts) > 0 { + req.Input = &schemas.EmbeddingInput{Texts: r.Texts} + } + // image-only (r.Images) or mixed (r.Inputs): req.Input stays nil; data flows via ExtraParams + + extraParams := make(map[string]interface{}) + // Forward known embedding-only params into ExtraParams so the provider can pick them up + if r.InputType != nil { + extraParams["input_type"] = *r.InputType + } + if r.Normalize != nil { + extraParams["normalize"] = *r.Normalize + } + if len(r.EmbeddingTypes) > 0 { + extraParams["embedding_types"] = r.EmbeddingTypes + } + if r.Truncate != nil { + extraParams["truncate"] = *r.Truncate + } + if len(r.Images) > 0 { + extraParams["images"] = r.Images + } + if len(r.Inputs) > 0 { + extraParams["inputs"] = r.Inputs + } + if r.MaxTokens != nil { + 
extraParams["max_tokens"] = *r.MaxTokens + } + // Merge any remaining extra params from the request + for k, v := range r.ExtraParams { + extraParams[k] = v + } + + // output_dimension maps to Dimensions; prefer OutputDimension over Dimensions + dimensions := r.Dimensions + if r.OutputDimension != nil { + dimensions = r.OutputDimension + } + params := &schemas.EmbeddingParameters{ + Dimensions: dimensions, + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + req.Params = params + + return req +} + +// ToBifrostImageGenerationRequest converts the invoke request to a BifrostImageGenerationRequest. +// Handles Titan/Nova Canvas (taskType=TEXT_IMAGE with textToImageParams) and Stability AI (flat prompt fields). +func (r *BedrockInvokeRequest) ToBifrostImageGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageGenerationRequest { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageGenerationRequest{ + Provider: provider, + Model: model, + } + + params := &schemas.ImageGenerationParameters{ + NegativePrompt: r.NegativePrompt, + AspectRatio: r.AspectRatio, + N: r.N, + OutputFormat: r.OutputFormat, + Seed: r.Seed, + } + + if r.TextToImageParams != nil { + // Titan / Nova Canvas path + req.Input = &schemas.ImageGenerationInput{Prompt: r.TextToImageParams.Text} + if r.TextToImageParams.NegativeText != nil { + params.NegativePrompt = r.TextToImageParams.NegativeText + } + if r.TextToImageParams.Style != nil { + params.Style = r.TextToImageParams.Style + } + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != 
nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path — prompt comes from the top-level "prompt" field + req.Input = &schemas.ImageGenerationInput{Prompt: r.Prompt} + } + + // Forward any remaining ExtraParams + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + + req.Params = params + return req +} + +// ToBifrostImageEditRequest converts the invoke request to a BifrostImageEditRequest. +// Handles Titan/Nova Canvas (taskType in INPAINTING/OUTPAINTING/BACKGROUND_REMOVAL) and Stability AI (flat image/mask fields). +func (r *BedrockInvokeRequest) ToBifrostImageEditRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageEditRequest, error) { + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageEditRequest{ + Provider: provider, + Model: model, + } + params := &schemas.ImageEditParameters{ + NegativePrompt: r.NegativePrompt, + Seed: r.Seed, + } + + if r.TaskType != nil { + // Titan / Nova Canvas path + switch *r.TaskType { + case TaskTypeInpainting: + if r.InPaintingParams == nil { + return nil, fmt.Errorf("inPaintingParams required for INPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.InPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.InPaintingParams.Text, + } + params.Type = schemas.Ptr("inpainting") + if r.InPaintingParams.NegativeText != nil { + params.NegativePrompt = 
r.InPaintingParams.NegativeText + } + if r.InPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.InPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode inpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.InPaintingParams.MaskPrompt != nil || r.InPaintingParams.ReturnMask != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.InPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.InPaintingParams.MaskPrompt + } + if r.InPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.InPaintingParams.ReturnMask + } + } + + case TaskTypeOutpainting: + if r.OutPaintingParams == nil { + return nil, fmt.Errorf("outPaintingParams required for OUTPAINTING task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.OutPaintingParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.OutPaintingParams.Text, + } + params.Type = schemas.Ptr("outpainting") + if r.OutPaintingParams.NegativeText != nil { + params.NegativePrompt = r.OutPaintingParams.NegativeText + } + if r.OutPaintingParams.MaskImage != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.OutPaintingParams.MaskImage) + if err != nil { + return nil, fmt.Errorf("failed to decode outpainting mask: %w", err) + } + params.Mask = maskBytes + } + if r.OutPaintingParams.MaskPrompt != nil || r.OutPaintingParams.ReturnMask != nil || r.OutPaintingParams.OutPaintingMode != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + if r.OutPaintingParams.MaskPrompt != nil { + params.ExtraParams["mask_prompt"] = *r.OutPaintingParams.MaskPrompt + } + if r.OutPaintingParams.ReturnMask != nil { + params.ExtraParams["return_mask"] = *r.OutPaintingParams.ReturnMask 
+ } + if r.OutPaintingParams.OutPaintingMode != nil { + params.ExtraParams["outpainting_mode"] = *r.OutPaintingParams.OutPaintingMode + } + } + + case TaskTypeBackgroundRemoval: + if r.BackgroundRemovalParams == nil { + return nil, fmt.Errorf("backgroundRemovalParams required for BACKGROUND_REMOVAL task") + } + imgBytes, err := base64.StdEncoding.DecodeString(r.BackgroundRemovalParams.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode background removal image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + } + params.Type = schemas.Ptr("background_removal") + + default: + return nil, fmt.Errorf("unsupported taskType for image edit: %s", *r.TaskType) + } + + // Map imageGenerationConfig fields into edit params (Titan/Nova Canvas only) + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + params.Seed = cfg.Seed + params.Quality = cfg.Quality + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.CfgScale != nil { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}) + } + params.ExtraParams["cfgScale"] = *cfg.CfgScale + } + } + } else { + // Stability AI path + if r.Image == nil { + return nil, fmt.Errorf("image field is required for Stability AI image edit") + } + imgBytes, err := base64.StdEncoding.DecodeString(*r.Image) + if err != nil { + return nil, fmt.Errorf("failed to decode stability AI image: %w", err) + } + req.Input = &schemas.ImageEditInput{ + Images: []schemas.ImageInput{{Image: imgBytes}}, + Prompt: r.Prompt, + } + // Infer task type from model name + taskType, err := getStabilityAIEditTaskType(r.ModelID) + if err != nil { + return nil, fmt.Errorf("cannot determine Stability AI edit task: %w", err) + } + params.Type = &taskType + if r.Mask != nil { + maskBytes, err := base64.StdEncoding.DecodeString(*r.Mask) + if err != nil { + return 
nil, fmt.Errorf("failed to decode stability AI mask: %w", err) + } + params.Mask = maskBytes + } + } + + if len(r.ExtraParams) > 0 { + if params.ExtraParams == nil { + params.ExtraParams = make(map[string]interface{}, len(r.ExtraParams)) + } + for k, v := range r.ExtraParams { + params.ExtraParams[k] = v + } + } + req.Params = params + return req, nil +} + +// ToBifrostImageVariationRequest converts the invoke request to a BifrostImageVariationRequest. +// Reads from imageVariationParams (Titan/Nova Canvas format). +func (r *BedrockInvokeRequest) ToBifrostImageVariationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostImageVariationRequest, error) { + if r.ImageVariationParams == nil || len(r.ImageVariationParams.Images) == 0 { + return nil, fmt.Errorf("imageVariationParams.images is required for IMAGE_VARIATION") + } + + primaryBytes, err := base64.StdEncoding.DecodeString(r.ImageVariationParams.Images[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode primary variation image: %w", err) + } + + modelID := r.ModelID + if unescaped, err := url.PathUnescape(r.ModelID); err == nil { + modelID = unescaped + } + provider, model := schemas.ParseModelString(modelID, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)) + req := &schemas.BifrostImageVariationRequest{ + Provider: provider, + Model: model, + Input: &schemas.ImageVariationInput{ + Image: schemas.ImageInput{Image: primaryBytes}, + }, + } + + params := &schemas.ImageVariationParameters{} + extraParams := make(map[string]interface{}) + + // Additional images (index 1+) stored under "images" key for the provider + if len(r.ImageVariationParams.Images) > 1 { + additionalImages := make([][]byte, 0, len(r.ImageVariationParams.Images)-1) + for _, imgB64 := range r.ImageVariationParams.Images[1:] { + imgBytes, err := base64.StdEncoding.DecodeString(imgB64) + if err != nil { + return nil, fmt.Errorf("failed to decode additional variation image: %w", err) + } + additionalImages = 
append(additionalImages, imgBytes) + } + extraParams["images"] = additionalImages + } + + // Text / negative text / similarity strength go to ExtraParams (provider reads them from there) + if r.ImageVariationParams.Text != nil { + extraParams["prompt"] = *r.ImageVariationParams.Text + } + if r.ImageVariationParams.NegativeText != nil { + extraParams["negativeText"] = *r.ImageVariationParams.NegativeText + } + if r.ImageVariationParams.SimilarityStrength != nil { + extraParams["similarityStrength"] = *r.ImageVariationParams.SimilarityStrength + } + + // ImageGenerationConfig → N, Size, Seed, Quality, CfgScale + if cfg := r.ImageGenerationConfig; cfg != nil { + params.N = cfg.NumberOfImages + if cfg.Width != nil && cfg.Height != nil { + size := fmt.Sprintf("%dx%d", *cfg.Width, *cfg.Height) + params.Size = &size + } + if cfg.Seed != nil { + extraParams["seed"] = *cfg.Seed + } + if cfg.Quality != nil { + extraParams["quality"] = *cfg.Quality + } + if cfg.CfgScale != nil { + extraParams["cfgScale"] = *cfg.CfgScale + } + } + + // Forward any remaining ExtraParams from the request body + for k, v := range r.ExtraParams { + extraParams[k] = v + } + if len(extraParams) > 0 { + params.ExtraParams = extraParams + } + + req.Params = params + return req, nil +} + // buildCohereCommandRPrompt converts Cohere Command R's message + chat_history into a text prompt. func (r *BedrockInvokeRequest) buildCohereCommandRPrompt() string { var sb strings.Builder @@ -448,9 +894,16 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. 
return nil, fmt.Errorf("bifrost response is nil") } - model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + model := "" + if resp.Model != "" { + model = resp.Model + } else { + extraFields := resp.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } } // Nova models: delegate to existing ToBedrockConverseResponse (Nova InvokeModel matches Converse format) @@ -467,6 +920,101 @@ func ToBedrockInvokeMessagesResponse(ctx *schemas.BifrostContext, resp *schemas. return toBedrockInvokeAnthropicResponse(resp, model), nil } +func ToBedrockInvokeImagesResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostImageGenerationResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim (preserves seeds, finish_reasons, etc.) + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + // Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas + if isStabilityAIModel(model) { + return ToStabilityAIImageGenerationResponse(resp) + } + + // Default: Titan Image Generator v1/v2, Nova Canvas — reconstruct from Bifrost data + result := &BedrockImageGenerationResponse{} + for _, d := range resp.Data { + result.Images = append(result.Images, d.B64JSON) + } + return result, nil +} + +// ToBedrockEmbeddingInvokeResponse converts a BifrostEmbeddingResponse back to the native +// Bedrock invoke API response format. 
+// Single-embedding (Titan) responses use: {"embedding": [...], "inputTextTokenCount": N} +// Multi-embedding (Cohere) responses use: {"embeddings": [[...],[...]], "response_type": "embeddings_floats"} +func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (interface{}, error) { + if resp == nil { + return nil, fmt.Errorf("bifrost embedding response is nil") + } + + // If the provider stored the raw Bedrock response, return it verbatim + if resp.ExtraFields.RawResponse != nil { + return resp.ExtraFields.RawResponse, nil + } + + tokenCount := 0 + if resp.Usage != nil { + tokenCount = resp.Usage.PromptTokens + } + + if len(resp.Data) == 0 { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + + // Use model name to distinguish Cohere from Titan — not batch size. + // A single-input Cohere request must still return the Cohere envelope format. + model := resp.Model + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } + } + + if strings.Contains(strings.ToLower(model), "cohere") { + floats := make([][]float32, 0, len(resp.Data)) + for _, d := range resp.Data { + float32Emb := make([]float32, len(d.Embedding.EmbeddingArray)) + for i, v := range d.Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + floats = append(floats, float32Emb) + } + return &BedrockInvokeCohereEmbeddingResp{ + Embeddings: floats, + ResponseType: "embeddings_floats", + }, nil + } + + // Titan format + if resp.Data[0].Embedding.EmbeddingArray == nil { + return &BedrockInvokeEmbeddingResp{InputTextTokenCount: tokenCount}, nil + } + float32Emb := make([]float32, len(resp.Data[0].Embedding.EmbeddingArray)) + for i, v := range resp.Data[0].Embedding.EmbeddingArray { + float32Emb[i] = float32(v) + } + return &BedrockInvokeEmbeddingResp{ + Embedding: float32Emb, + 
InputTextTokenCount: tokenCount, + }, nil +} + // toBedrockInvokeAnthropicResponse converts BifrostResponsesResponse to Anthropic Messages API format. func toBedrockInvokeAnthropicResponse(resp *schemas.BifrostResponsesResponse, model string) *BedrockInvokeMessagesResponse { result := &BedrockInvokeMessagesResponse{ @@ -623,12 +1171,17 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc // final Completed event). Without checking resp.ExtraFields, early chunks would // have model="" and Nova streams would be mis-routed through the Anthropic path. model := "" - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.ExtraFields.ModelRequested != "" { - model = resp.Response.ExtraFields.ModelRequested - } else if resp.Response != nil && resp.Response.Model != "" { - model = resp.Response.Model + if resp.Response != nil { + if resp.Response.Model != "" { + model = resp.Response.Model + } else { + extraFields := resp.Response.ExtraFields + if extraFields.ResolvedModelUsed != "" { + model = extraFields.ResolvedModelUsed + } else if extraFields.OriginalModelRequested != "" { + model = extraFields.OriginalModelRequested + } + } } // Nova models: delegate to existing converse stream response (same format) @@ -657,6 +1210,7 @@ func ToBedrockInvokeMessagesStreamResponse(ctx *schemas.BifrostContext, resp *sc bedrockEvent := &BedrockStreamEvent{ InvokeModelRawChunks: rawChunks, } + return "", bedrockEvent, nil } @@ -669,8 +1223,11 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) switch resp.Type { case schemas.ResponsesStreamResponseTypeCreated: - // message_start — use ExtraFields.ModelRequested as fallback for early chunks - model := resp.ExtraFields.ModelRequested + // message_start — prefer resolved model for accurate family detection on early chunks + model := resp.ExtraFields.ResolvedModelUsed + if model == "" { + model = 
resp.ExtraFields.OriginalModelRequested + } msgStart := map[string]interface{}{ "type": "message_start", "message": map[string]interface{}{ @@ -780,7 +1337,7 @@ func toAnthropicInvokeStreamBytes(resp *schemas.BifrostResponsesStreamResponse) "type": "content_block_delta", "index": idx, "delta": map[string]interface{}{ - "type": "input_json_delta", + "type": "input_json_delta", "partial_json": *resp.Delta, }, } diff --git a/core/providers/bedrock/models.go b/core/providers/bedrock/models.go index e4e96a8017..549db2e3bd 100644 --- a/core/providers/bedrock/models.go +++ b/core/providers/bedrock/models.go @@ -1,7 +1,6 @@ package bedrock import ( - "slices" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -82,147 +81,7 @@ type BedrockRerankResponseDocument struct { TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` } -// regionPrefixes is a list of region prefixes used in Bedrock deployments -// Based on AWS region naming patterns and Bedrock deployment configurations -var regionPrefixes = []string{ - "us.", // US regions (us-east-1, us-west-2, etc.) - "eu.", // Europe regions (eu-west-1, eu-central-1, etc.) - "ap.", // Asia Pacific regions (ap-southeast-1, ap-northeast-1, etc.) - "ca.", // Canada regions (ca-central-1, etc.) - "sa.", // South America regions (sa-east-1, etc.) - "af.", // Africa regions (af-south-1, etc.) - "global.", // Global deployment prefix -} - -// extractPrefix extracts the region prefix ending with '.' from a string -// Only recognizes common region prefixes like "us.", "global.", "eu.", etc. -// Returns the prefix (including the dot) if found, empty string otherwise -func extractPrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return prefix - } - } - return "" -} - -// removePrefix removes any region prefix ending with '.' from a string -// Only removes common region prefixes like "us.", "global.", "eu.", etc. 
-// Returns the string without the prefix -func removePrefix(s string) string { - for _, prefix := range regionPrefixes { - if strings.HasPrefix(s, prefix) { - return s[len(prefix):] - } - } - return s -} - -// findMatchingAllowedModel finds a matching item in a slice, considering both -// exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// Returns the matched item from the slice if found, empty string otherwise. -// If matched via base model, returns the item from slice (not the value parameter). -func findMatchingAllowedModel(slice []string, value string) string { - // First check exact matches - if slices.Contains(slice, value) { - return value - } - - // Check with region prefix added/removed - valuePrefix := extractPrefix(value) - if valuePrefix != "" { - // value has a prefix, check if slice contains version without prefix - withoutPrefix := removePrefix(value) - if slices.Contains(slice, withoutPrefix) { - return withoutPrefix - } - } - - // Check if any item in slice has a prefix that matches value without prefix - for _, item := range slice { - itemPrefix := extractPrefix(item) - if itemPrefix != "" { - // item has prefix, check if value matches without the prefix - itemWithoutPrefix := removePrefix(item) - if itemWithoutPrefix == value { - return item - } - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize value by removing any region prefix for base model comparison - valueNormalized := removePrefix(value) - - for _, item := range slice { - // Normalize item by removing any region prefix for base model comparison - itemNormalized := removePrefix(item) - - // Check base model match with normalized values (prefix removed from both) - // Return the item from slice (not value) to use the actual name from allowedModels - if 
schemas.SameBaseModel(itemNormalized, valueNormalized) { - return item - } - } - return "" -} - -// findDeploymentMatch finds a matching deployment value in the deployments map, -// considering both exact match and match with/without region prefixes (e.g., "global.", "us.", "eu."), -// and also checks base model matches (ignoring version suffixes). -// The modelID from the API response should match a deployment value (not the alias/key). -// Returns the deployment value and alias if found, empty strings otherwise. -func findDeploymentMatch(deployments map[string]string, modelID string) (deploymentValue, alias string) { - // Check if any deployment value matches the modelID (with or without prefix) - for aliasKey, deploymentValue := range deployments { - // Exact match - if deploymentValue == modelID || aliasKey == modelID { - return deploymentValue, aliasKey - } - - // Check prefix variations - deploymentPrefix := extractPrefix(deploymentValue) - modelIDPrefix := extractPrefix(modelID) - aliasKeyPrefix := extractPrefix(aliasKey) - - // Case 1: deploymentValue or aliasKey has prefix, modelID doesn't - if (deploymentPrefix != "" && modelIDPrefix == "") || (aliasKeyPrefix != "" && modelIDPrefix == "") { - if removePrefix(deploymentValue) == modelID || removePrefix(aliasKey) == modelID { - return deploymentValue, aliasKey - } - } - - // Case 2: modelID or aliasKey has prefix, deploymentValue doesn't - if (modelIDPrefix != "" && deploymentPrefix == "") || (aliasKeyPrefix != "" && deploymentPrefix == "") { - if removePrefix(modelID) == deploymentValue || removePrefix(modelID) == aliasKey { - return deploymentValue, aliasKey - } - } - - // Case 3: Both have prefixes but different prefixes - if (deploymentPrefix != "" && modelIDPrefix != "" && deploymentPrefix != modelIDPrefix) || (aliasKeyPrefix != "" && modelIDPrefix != "" && aliasKeyPrefix != modelIDPrefix) { - if removePrefix(deploymentValue) == removePrefix(modelID) || removePrefix(aliasKey) == removePrefix(modelID) { 
- return deploymentValue, aliasKey - } - } - - // Additional layer: check base model matches (ignoring version suffixes) - // This handles cases where model versions differ but base model is the same - // Normalize both values by removing any region prefix for base model comparison - deploymentNormalized := removePrefix(deploymentValue) - modelIDNormalized := removePrefix(modelID) - - // Check base model match with normalized values (prefix removed from both) - if schemas.SameBaseModel(deploymentNormalized, modelIDNormalized) { - return deploymentValue, aliasKey - } - } - return "", "" -} - -func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -231,121 +90,41 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK Data: make([]schemas.Model, 0, len(response.ModelSummaries)), } - deploymentValues := make([]string, 0, len(deployments)) - for _, deployment := range deployments { - deploymentValues = append(deploymentValues, deployment) + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), } - - includedModels := make(map[string]bool) - for _, model := range response.ModelSummaries { - modelID := model.ModelID - matchedAllowedModel := "" - deploymentValue := "" - deploymentAlias := "" - - // Filter if model is not present in both lists (when both are non-empty) - // Empty lists mean 
"allow all" for that dimension - // Check considering global prefix variations - shouldFilter := false - if !unfiltered && len(allowedModels) > 0 && len(deploymentValues) > 0 { - // Both lists are present: model must be in allowedModels AND deployments - // AND the deployment alias must also be in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - inDeployments := deploymentAlias != "" - - // Check if deployment alias is also in allowedModels (direct string match) - deploymentAliasInAllowedModels := false - if deploymentAlias != "" { - deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias) - } - - // Filter if: model not in deployments OR deployment alias not in allowedModels - shouldFilter = !inDeployments || !deploymentAliasInAllowedModels - } else if !unfiltered && len(allowedModels) > 0 { - // Only allowedModels is present: filter if model is not in allowedModels - matchedAllowedModel = findMatchingAllowedModel(allowedModels, model.ModelID) - shouldFilter = matchedAllowedModel == "" - } else if !unfiltered && len(deploymentValues) > 0 { - // Only deployments is present: filter if model is not in deployments - deploymentValue, deploymentAlias = findDeploymentMatch(deployments, model.ModelID) - shouldFilter = deploymentValue == "" - } - // If both are empty, shouldFilter remains false (allow all) - - if shouldFilter { - continue - } - - // Use the matched name from allowedModels or deployments (like Anthropic) - // Priority: deployment value > matched allowedModel > original model.ModelID - if deploymentValue != "" { - modelID = deploymentValue - } else if matchedAllowedModel != "" { - modelID = matchedAllowedModel - } - - if !unfiltered && providerUtils.ModelMatchesDenylist(blacklistedModels, model.ModelID, modelID, deploymentAlias, matchedAllowedModel) { - continue - } - - modelEntry := schemas.Model{ - ID: 
string(providerKey) + "/" + modelID, - Name: schemas.Ptr(model.ModelName), - OwnedBy: schemas.Ptr(model.ProviderName), - Architecture: &schemas.Architecture{ - InputModalities: model.InputModalities, - OutputModalities: model.OutputModalities, - }, - } - // Set deployment info if matched via deployments - if deploymentValue != "" && deploymentAlias != "" { - modelEntry.ID = string(providerKey) + "/" + deploymentAlias - // Use the actual deployment value (which might have global prefix) - modelEntry.Deployment = schemas.Ptr(deploymentValue) - includedModels[deploymentAlias] = true - } else { - includedModels[modelID] = true - } - bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill deployments that were not matched from the API response - if !unfiltered && len(deployments) > 0 { - for alias, deploymentValue := range deployments { - if includedModels[alias] { - continue - } - // If allowedModels is non-empty, only include if alias is in the list - if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) { - continue - } - if providerUtils.ModelMatchesDenylist(blacklistedModels, alias) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + alias, - Name: schemas.Ptr(alias), - Deployment: schemas.Ptr(deploymentValue), - }) - includedModels[alias] = true - } - } + included := make(map[string]bool) - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if providerUtils.ModelMatchesDenylist(blacklistedModels, allowedModel) { - continue + for _, model := range response.ModelSummaries { + for _, result := range pipeline.FilterModel(model.ModelID) { + modelEntry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.ModelName), + OwnedBy: schemas.Ptr(model.ProviderName), + 
Architecture: &schemas.Architecture{ + InputModalities: model.InputModalities, + OutputModalities: model.OutputModalities, + }, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + modelEntry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, modelEntry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse -} +} \ No newline at end of file diff --git a/core/providers/bedrock/rerank_test.go b/core/providers/bedrock/rerank_test.go index 0dff5c3ee2..c1b7bb5480 100644 --- a/core/providers/bedrock/rerank_test.go +++ b/core/providers/bedrock/rerank_test.go @@ -195,27 +195,23 @@ func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) { func TestResolveBedrockDeployment(t *testing.T) { key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", }, } - deployment := resolveBedrockDeployment("cohere-rerank", key) + deployment := key.Aliases.Resolve("cohere-rerank") assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment) - assert.Equal(t, "cohere.rerank-v3-5:0", resolveBedrockDeployment("cohere.rerank-v3-5:0", key)) - assert.Equal(t, "", resolveBedrockDeployment("", key)) + assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0")) + assert.Equal(t, "", key.Aliases.Resolve("")) } func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) { provider := &BedrockProvider{} ctx := 
schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) key := schemas.Key{ - BedrockKeyConfig: &schemas.BedrockKeyConfig{ - Deployments: map[string]string{ - "cohere-rerank": "cohere.rerank-v3-5:0", - }, + Aliases: schemas.KeyAliases{ + "cohere-rerank": "cohere.rerank-v3-5:0", }, } diff --git a/core/providers/bedrock/responses.go b/core/providers/bedrock/responses.go index 911865690c..1f53e12083 100644 --- a/core/providers/bedrock/responses.go +++ b/core/providers/bedrock/responses.go @@ -1500,20 +1500,30 @@ func (request *BedrockConverseRequest) ToBifrostResponsesRequest(ctx *schemas.Bi if summaryValue, ok := schemas.SafeExtractStringPointer(request.ExtraParams["reasoning_summary"]); ok { summary = summaryValue } - // Check for native output_config.effort first + var ( + effortStr string + found bool + ) + // Check for native output_config.effort first. + // output_config may be preserved as OrderedMap by the merge path. if outputConfig, ok := request.AdditionalModelRequestFields.Get("output_config"); ok { - if outputConfigMap, ok := outputConfig.(map[string]interface{}); ok { - if effortStr, ok := schemas.SafeExtractString(outputConfigMap["effort"]); ok { - var maxTokens *int - if budgetTokens, ok := schemas.SafeExtractInt(reasoningConfigMap["budget_tokens"]); ok { - maxTokens = schemas.Ptr(budgetTokens) - } - bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ - Effort: schemas.Ptr(effortStr), - MaxTokens: maxTokens, - Summary: summary, - } + if outputConfigOrderedMap, ok := schemas.SafeExtractOrderedMap(outputConfig); ok && outputConfigOrderedMap != nil { + if effortValue, exists := outputConfigOrderedMap.Get("effort"); exists { + effortStr, found = schemas.SafeExtractString(effortValue) } + } else if outputConfigMap, ok := outputConfig.(map[string]interface{}); ok { + effortStr, found = schemas.SafeExtractString(outputConfigMap["effort"]) + } + } + if found { + var maxTokens *int + if budgetTokens, ok := 
schemas.SafeExtractInt(reasoningConfigMap["budget_tokens"]); ok { + maxTokens = schemas.Ptr(budgetTokens) + } + bifrostReq.Params.Reasoning = &schemas.ResponsesParametersReasoning{ + Effort: schemas.Ptr(effortStr), + MaxTokens: maxTokens, + Summary: summary, } } else if maxTokens, ok := schemas.SafeExtractInt(reasoningConfigMap["budget_tokens"]); ok { // Fallback: convert budget_tokens to effort @@ -1673,6 +1683,8 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. } } + var responsesStructuredOutputTool *BedrockTool + // Map basic parameters to inference config if bifrostReq.Params != nil { inferenceConfig := &BedrockInferenceConfig{} @@ -1770,9 +1782,7 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. bedrockReq.AdditionalModelRequestFields.Set("thinking", map[string]any{ "type": "adaptive", }) - bedrockReq.AdditionalModelRequestFields.Set("output_config", map[string]any{ - "effort": effort, - }) + setOutputConfigField(bedrockReq.AdditionalModelRequestFields, "effort", effort) } else { // Opus 4.5 and older Anthropic models: budget_tokens thinking modelDefaultMaxTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens) @@ -1829,19 +1839,17 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. 
} if bifrostReq.Params.Text != nil { if bifrostReq.Params.Text.Format != nil { - responseFormatTool := convertTextFormatToTool(ctx, bifrostReq.Params.Text) - // append to bedrockTools - if responseFormatTool != nil { - if bedrockReq.ToolConfig == nil { - bedrockReq.ToolConfig = &BedrockToolConfig{} - } - bedrockReq.ToolConfig.Tools = append(bedrockReq.ToolConfig.Tools, *responseFormatTool) - // Force the model to use this specific tool (same as ChatCompletion) - bedrockReq.ToolConfig.ToolChoice = &BedrockToolChoice{ - Tool: &BedrockToolChoiceTool{ - Name: responseFormatTool.ToolSpec.Name, - }, + responseFormatTool, anthropicOutputFormat := convertTextFormatToTool(ctx, bifrostReq.Model, bifrostReq.Params.Text) + if anthropicOutputFormat != nil { + if bedrockReq.AdditionalModelRequestFields == nil { + bedrockReq.AdditionalModelRequestFields = schemas.NewOrderedMap() } + setOutputConfigField(bedrockReq.AdditionalModelRequestFields, "format", anthropicOutputFormat) + } + // Defer synthetic tool injection until after normal tool/tool_choice conversion + // so the structured-output tool is not overwritten by the later pass. + if responseFormatTool != nil { + responsesStructuredOutputTool = responseFormatTool } } } @@ -1855,7 +1863,10 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. if requestFields, exists := bifrostReq.Params.ExtraParams["additionalModelRequestFieldPaths"]; exists { if orderedFields, ok := schemas.SafeExtractOrderedMap(requestFields); ok { delete(bedrockReq.ExtraParams, "additionalModelRequestFieldPaths") - bedrockReq.AdditionalModelRequestFields = orderedFields + bedrockReq.AdditionalModelRequestFields = mergeAdditionalModelRequestFields( + bedrockReq.AdditionalModelRequestFields, + orderedFields, + ) } } @@ -1959,6 +1970,20 @@ func ToBedrockResponsesRequest(ctx *schemas.BifrostContext, bifrostReq *schemas. 
} } + // If text.format was converted to a synthetic tool, inject it after the normal + // tool/tool_choice pass so it is not overwritten by the above conversion. + if responsesStructuredOutputTool != nil { + if bedrockReq.ToolConfig == nil { + bedrockReq.ToolConfig = &BedrockToolConfig{} + } + bedrockReq.ToolConfig.Tools = append([]BedrockTool{*responsesStructuredOutputTool}, bedrockReq.ToolConfig.Tools...) + bedrockReq.ToolConfig.ToolChoice = &BedrockToolChoice{ + Tool: &BedrockToolChoiceTool{ + Name: responsesStructuredOutputTool.ToolSpec.Name, + }, + } + } + // Ensure tool config is present when tool content exists (similar to Chat Completions) ensureResponsesToolConfigForConversation(bifrostReq, bedrockReq) diff --git a/core/providers/bedrock/s3.go b/core/providers/bedrock/s3.go index da06e5e820..be2d0afb32 100644 --- a/core/providers/bedrock/s3.go +++ b/core/providers/bedrock/s3.go @@ -22,7 +22,6 @@ func uploadToS3( region string, bucket, key string, content []byte, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Create AWS config with credentials var cfg aws.Config @@ -47,7 +46,7 @@ func uploadToS3( } if err != nil { - return providerUtils.NewBifrostOperationError("failed to load AWS config for S3", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err) } // Create S3 client @@ -62,7 +61,7 @@ func uploadToS3( }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to S3: %s/%s", bucket, key), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err) } return nil diff --git a/core/providers/bedrock/signer.go b/core/providers/bedrock/signer.go index 9f12e3bbaf..b7e87ae8d2 100644 --- a/core/providers/bedrock/signer.go +++ b/core/providers/bedrock/signer.go @@ -280,17 +280,16 @@ func signAWSRequestFastHTTP( accessKey, secretKey string, sessionToken *string, region, service 
string, - providerName schemas.ModelProvider, ) *schemas.BifrostError { // Get AWS credentials if not provided if accessKey == "" && secretKey == "" { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) if err != nil { - return providerUtils.NewBifrostOperationError("failed to load aws config", err, providerName) + return providerUtils.NewBifrostOperationError("failed to load aws config", err) } creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { - return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err, providerName) + return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } accessKey = creds.AccessKeyID secretKey = creds.SecretAccessKey diff --git a/core/providers/bedrock/text.go b/core/providers/bedrock/text.go index 6ad24ee1c8..d31d716ded 100644 --- a/core/providers/bedrock/text.go +++ b/core/providers/bedrock/text.go @@ -127,8 +127,6 @@ func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -154,8 +152,6 @@ func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *s Object: "text_completion", Choices: choices, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionRequest, - Provider: schemas.Bedrock, }, } } @@ -167,11 +163,14 @@ func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionR return nil } - // Determine response format based on model - // Use ModelRequested from ExtraFields if available, otherwise use Model + // Determine response format based on resolved model identity. + // Use ResolvedModelUsed (actual provider ID) for accurate family detection, + // falling back to bifrostResp.Model, then OriginalModelRequested as a last resort. 
model := bifrostResp.Model - if bifrostResp.ExtraFields.ModelRequested != "" { - model = bifrostResp.ExtraFields.ModelRequested + if bifrostResp.ExtraFields.ResolvedModelUsed != "" { + model = bifrostResp.ExtraFields.ResolvedModelUsed + } else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" { + model = bifrostResp.ExtraFields.OriginalModelRequested } if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") { diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 6751527b5b..1e2a447e9d 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -138,7 +138,7 @@ func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - _, _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") + _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") require.NotNil(t, bifrostErr, "expected error when server closes connection") assert.False(t, bifrostErr.IsBifrostError, diff --git a/core/providers/bedrock/types.go b/core/providers/bedrock/types.go index 98c46ae96f..afbfead01f 100644 --- a/core/providers/bedrock/types.go +++ b/core/providers/bedrock/types.go @@ -654,10 +654,10 @@ type BedrockMetadataEvent struct { // BedrockTitanEmbeddingRequest represents a Bedrock Titan embedding request type BedrockTitanEmbeddingRequest struct { - InputText string `json:"inputText"` // Required: Text to embed + InputText string `json:"inputText"` // Required: Text to embed + Dimensions *int `json:"dimensions,omitempty"` // Optional: 256, 512, or 1024 (titan-embed-text-v2 only) + Normalize *bool `json:"normalize,omitempty"` // Optional: normalize the embedding ExtraParams map[string]interface{} `json:"-"` - // Note: Titan models have fixed dimensions and don't support the 
dimensions parameter - // ExtraParams can be used for any additional model-specific parameters } // GetExtraParams implements the RequestBodyWithExtraParams interface @@ -671,6 +671,53 @@ type BedrockTitanEmbeddingResponse struct { InputTextTokenCount int `json:"inputTextTokenCount"` // Number of tokens in input } +// BedrockCohereEmbeddingContentBlock represents a single content block in a mixed input +type BedrockCohereEmbeddingContentBlock struct { + Type string `json:"type"` // "text" or "image_url" + Text *string `json:"text,omitempty"` // for type=text + ImageURL *BedrockCohereEmbeddingImageURL `json:"image_url,omitempty"` // for type=image_url +} + +// BedrockCohereEmbeddingImageURL holds the URL for an image content block +type BedrockCohereEmbeddingImageURL struct { + URL string `json:"url"` +} + +// BedrockCohereEmbeddingInput represents a mixed text+image input +type BedrockCohereEmbeddingInput struct { + Content []BedrockCohereEmbeddingContentBlock `json:"content"` +} + +// BedrockCohereEmbeddingRequest represents a Bedrock Cohere embedding request. +// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the body. +type BedrockCohereEmbeddingRequest struct { + InputType string `json:"input_type"` // Required + Texts []string `json:"texts,omitempty"` // text-only inputs + Images []string `json:"images,omitempty"` // image-only inputs (data URIs) + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // mixed text+image inputs + EmbeddingTypes []string `json:"embedding_types,omitempty"` // e.g. 
["float"] + OutputDimension *int `json:"output_dimension,omitempty"` // 256, 512, 1024, or 1536 + MaxTokens *int `json:"max_tokens,omitempty"` // max 128000 + Truncate *string `json:"truncate,omitempty"` // NONE, LEFT, or RIGHT + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *BedrockCohereEmbeddingRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockCohereEmbeddingResponse handles both Bedrock Cohere embedding response shapes. +// When embedding_types is not set, Bedrock returns embeddings as a raw [][]float32 +// ("embeddings_floats"). When embedding_types is set, it returns an object with typed +// arrays ("embeddings_by_type"). Using json.RawMessage defers parsing until we know the shape. +type BedrockCohereEmbeddingResponse struct { + ID string `json:"id"` + Embeddings json.RawMessage `json:"embeddings"` + ResponseType string `json:"response_type"` + Texts []string `json:"texts,omitempty"` +} + const TaskTypeTextImage = "TEXT_IMAGE" const TaskTypeImageVariation = "IMAGE_VARIATION" const TaskTypeInpainting = "INPAINTING" @@ -763,11 +810,79 @@ type BedrockBackgroundRemovalParams struct { Image string `json:"image"` // Base64-encoded image } -// BedrockImageGenerationResponse represents a Bedrock image generation response +// StabilityAIImageGenerationRequest represents the request format for Stability AI models on Bedrock +// (e.g. 
stability.stable-image-core-v1:1, stability.stable-image-ultra-v1:1) +type StabilityAIImageGenerationRequest struct { + Prompt string `json:"prompt"` + AspectRatio *string `json:"aspect_ratio,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + Seed *int `json:"seed,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +// GetExtraParams implements the RequestBodyWithExtraParams interface +func (req *StabilityAIImageGenerationRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// StabilityAIImageEditRequest is the flat JSON body for Stability AI image-edit models on Bedrock. +// Only the fields valid for the detected task type are populated. +type StabilityAIImageEditRequest struct { + // Shared params + Image *string `json:"image,omitempty"` // base64, primary input image + Prompt *string `json:"prompt,omitempty"` + NegativePrompt *string `json:"negative_prompt,omitempty"` + Seed *int `json:"seed,omitempty"` + OutputFormat *string `json:"output_format,omitempty"` + StylePreset *string `json:"style_preset,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask image + GrowMask *int `json:"grow_mask,omitempty"` + + // Outpaint + Left *int `json:"left,omitempty"` + Right *int `json:"right,omitempty"` + Up *int `json:"up,omitempty"` + Down *int `json:"down,omitempty"` + + // Upscale-creative / upscale-conservative / outpaint + Creativity *float64 `json:"creativity,omitempty"` + + // Recolor + SelectPrompt *string `json:"select_prompt,omitempty"` + + // Search-replace + SearchPrompt *string `json:"search_prompt,omitempty"` + + // Control-sketch / control-structure + ControlStrength *float64 `json:"control_strength,omitempty"` + + // Style-guide + AspectRatio *string `json:"aspect_ratio,omitempty"` + Fidelity *float64 `json:"fidelity,omitempty"` + + // Style-transfer (uses different image field names) + InitImage *string 
`json:"init_image,omitempty"` + StyleImage *string `json:"style_image,omitempty"` + StyleStrength *float64 `json:"style_strength,omitempty"` + CompositionFidelity *float64 `json:"composition_fidelity,omitempty"` + ChangeStrength *float64 `json:"change_strength,omitempty"` + + ExtraParams map[string]interface{} `json:"-"` +} + +func (req *StabilityAIImageEditRequest) GetExtraParams() map[string]interface{} { + return req.ExtraParams +} + +// BedrockImageGenerationResponse represents a Bedrock image generation response. +// The Seeds and FinishReasons fields are populated by Stability AI edit models only. type BedrockImageGenerationResponse struct { - Images []string `json:"images"` // list of Base64 encoded images - MaskImage string `json:"maskImage"` // Base64 encoded mask image (optional) - Error string `json:"error"` // error message (if present) + Images []string `json:"images"` // list of Base64 encoded images + MaskImage string `json:"maskImage,omitempty"` // Base64 encoded mask image (optional) + Error string `json:"error,omitempty"` // error message (if present) + Seeds []int `json:"seeds,omitempty"` // Stability AI: seeds used per image + FinishReasons []*string `json:"finish_reasons,omitempty"` // Stability AI: finish reason per image (may be null) } // ==================== MODELS TYPES ==================== @@ -960,6 +1075,37 @@ type BedrockInvokeRequest struct { FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` PresencePenalty *float64 `json:"presence_penalty,omitempty"` + // ==================== BEDROCK IMAGE GEN / EDIT / VARIATION (Titan/Nova Canvas) ==================== + + TaskType *string `json:"taskType,omitempty"` + TextToImageParams *BedrockTextToImageParams `json:"textToImageParams,omitempty"` + ImageVariationParams *BedrockImageVariationParams `json:"imageVariationParams,omitempty"` + InPaintingParams *BedrockInPaintingParams `json:"inPaintingParams,omitempty"` + OutPaintingParams *BedrockOutPaintingParams 
`json:"outPaintingParams,omitempty"` + BackgroundRemovalParams *BedrockBackgroundRemovalParams `json:"backgroundRemovalParams,omitempty"` + ImageGenerationConfig *ImageGenerationConfig `json:"imageGenerationConfig,omitempty"` + + // ==================== STABILITY AI IMAGE ==================== + + // Image is the base64-encoded input image (SA edit / variation) + Image *string `json:"image,omitempty"` + Mask *string `json:"mask,omitempty"` // base64 mask for inpainting + NegativePrompt *string `json:"negative_prompt,omitempty"` // SA gen / edit + AspectRatio *string `json:"aspect_ratio,omitempty"` // SA gen + OutputFormat *string `json:"output_format,omitempty"` // SA gen + Seed *int `json:"seed,omitempty"` // SA gen / edit + + // ==================== EMBEDDINGS ==================== + + InputText string `json:"inputText,omitempty"` // Titan embed + Texts []string `json:"texts,omitempty"` // Cohere embed + InputType *string `json:"input_type,omitempty"` // Cohere embed + Normalize *bool `json:"normalize,omitempty"` // Titan embed v2 + Dimensions *int `json:"dimensions,omitempty"` // Titan embed v2 + EmbeddingTypes []string `json:"embedding_types,omitempty"` // Cohere embed: ["float","int8","uint8","binary","ubinary"] + OutputDimension *int `json:"output_dimension,omitempty"` // Cohere embed: 256, 512, 1024, 1536 + Inputs []BedrockCohereEmbeddingInput `json:"inputs,omitempty"` // Cohere embed: mixed text+image inputs + // ==================== INTERNAL ==================== Stream bool `json:"-"` ExtraParams map[string]interface{} `json:"-"` @@ -970,3 +1116,15 @@ type BedrockCohereRMessage struct { Role string `json:"role"` // "USER" or "CHATBOT" Message string `json:"message"` // Message content } + +// BedrockInvokeEmbeddingResp is the Titan single-embedding invoke response format. 
+type BedrockInvokeEmbeddingResp struct {
+	Embedding           []float32 `json:"embedding"`
+	InputTextTokenCount int       `json:"inputTextTokenCount"`
+}
+
+// BedrockInvokeCohereEmbeddingResp is the Cohere multi-embedding invoke response format.
+type BedrockInvokeCohereEmbeddingResp struct {
+	Embeddings   [][]float32 `json:"embeddings"`
+	ResponseType string      `json:"response_type"`
+}
diff --git a/core/providers/bedrock/utils.go b/core/providers/bedrock/utils.go
index 1a8885ca54..4eb48452a0 100644
--- a/core/providers/bedrock/utils.go
+++ b/core/providers/bedrock/utils.go
@@ -74,14 +74,60 @@ func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.Bifr
 		bedrockReq.InferenceConfig = inferenceConfig
 	}
 
-	// Check for response_format and convert to tool
-	responseFormatTool := convertResponseFormatToTool(ctx, bifrostReq.Params)
+	// Handle structured output conversion:
+	// - Anthropic models on Bedrock use native output_config.format
+	// - Other models keep the response_format->tool conversion.
+	responseFormatTool, anthropicOutputFormat := convertResponseFormatToTool(ctx, bifrostReq.Model, bifrostReq.Params)
+	if anthropicOutputFormat != nil {
+		if bedrockReq.AdditionalModelRequestFields == nil {
+			bedrockReq.AdditionalModelRequestFields = schemas.NewOrderedMap()
+		}
+		setOutputConfigField(bedrockReq.AdditionalModelRequestFields, "format", anthropicOutputFormat)
+	}
 
-	// Convert tool config
-	if toolConfig := convertToolConfig(bifrostReq.Model, bifrostReq.Params); toolConfig != nil {
+	// Filter provider-unsupported server tools once; both convertToolConfig and
+	// collectBedrockServerTools consume the same filtered set, and
+	// buildBedrockServerToolChoice resolves pinned names against it.
+	filteredTools, _ := anthropic.ValidateChatToolsForProvider(bifrostReq.Params.Tools, schemas.Bedrock)
+
+	// Convert tool config (function/custom tools → Converse toolConfig.tools).
+	if toolConfig := convertToolConfigFromFiltered(bifrostReq.Model, bifrostReq.Params, filteredTools); toolConfig != nil {
 		bedrockReq.ToolConfig = toolConfig
 	}
 
+	// Tunnel Bedrock-supported Anthropic server tools through Converse's
+	// additionalModelRequestFields (model-specific passthrough) since Converse's
+	// typed toolSpec shape can't express server tools like bash_*, computer_*,
+	// memory_*, text_editor_*, tool_search_tool_*. Fields injected:
+	// - tools: array of server tools in Anthropic-native shape, which
+	//   Bedrock merges into the underlying Messages request.
+	// - anthropic_beta: activation header(s) for the relevant server tool, in
+	//   addition to whatever the existing anthropic-beta HTTP
+	//   header path in bedrock.go:214/447 already forwards.
+	// - tool_choice: Anthropic-native pin for a kept server tool OR an
+	//   any/required contract when only server tools are
+	//   present. Emitted only when Converse's typed
+	//   toolConfig.toolChoice path can't express the intent
+	//   (see buildBedrockServerToolChoice).
+	if serverTools, betaHeaders := collectBedrockServerToolsFromFiltered(filteredTools); len(serverTools) > 0 {
+		if bedrockReq.AdditionalModelRequestFields == nil {
+			bedrockReq.AdditionalModelRequestFields = schemas.NewOrderedMap()
+		}
+		bedrockReq.AdditionalModelRequestFields.Set("tools", serverTools)
+		if len(betaHeaders) > 0 {
+			bedrockReq.AdditionalModelRequestFields.Set("anthropic_beta", betaHeaders)
+		}
+		// Skip the tunneled tool_choice when response_format forces the synthetic
+		// bf_so_* tool at lines 263-275 below; otherwise Bedrock receives two
+		// conflicting tool-choice directives and the structured-output contract
+		// can silently break.
+		if responseFormatTool == nil {
+			if choice, ok := buildBedrockServerToolChoice(bifrostReq.Params, filteredTools); ok {
+				bedrockReq.AdditionalModelRequestFields.Set("tool_choice", choice)
+			}
+		}
+	}
+
 	// Convert reasoning config
 	if bifrostReq.Params.Reasoning != nil {
 		if bedrockReq.AdditionalModelRequestFields == nil {
@@ -190,9 +236,7 @@ func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.Bifr
 			bedrockReq.AdditionalModelRequestFields.Set("thinking", map[string]any{
 				"type": "adaptive",
 			})
-			bedrockReq.AdditionalModelRequestFields.Set("output_config", map[string]any{
-				"effort": effort,
-			})
+			setOutputConfigField(bedrockReq.AdditionalModelRequestFields, "effort", effort)
 		} else {
 			// Opus 4.5 and older models: budget_tokens thinking
 			budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, anthropic.MinimumReasoningMaxTokens, maxTokens)
@@ -270,7 +314,10 @@ func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.Bifr
 	if requestFields, exists := bifrostReq.Params.ExtraParams["additionalModelRequestFieldPaths"]; exists {
 		if orderedFields, ok := schemas.SafeExtractOrderedMap(requestFields); ok {
 			delete(bedrockReq.ExtraParams, "additionalModelRequestFieldPaths")
-			bedrockReq.AdditionalModelRequestFields = orderedFields
+			bedrockReq.AdditionalModelRequestFields = mergeAdditionalModelRequestFields(
+				bedrockReq.AdditionalModelRequestFields,
+				orderedFields,
+			)
 		}
 	}
 
@@ -341,6 +388,103 @@ func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.Bifr
 	return nil
 }
 
+// setOutputConfigField upserts a single key in additionalModelRequestFields.output_config
+// while preserving any existing output_config keys (e.g. keep "format" when adding "effort").
+func setOutputConfigField(fields *schemas.OrderedMap, key string, value any) {
+	if fields == nil {
+		return
+	}
+	current := schemas.NewOrderedMap()
+	if existing, ok := fields.Get("output_config"); ok {
+		if om, ok := toOrderedMap(existing); ok && om != nil {
+			current = om
+		}
+	}
+	current.Set(key, value)
+	fields.Set("output_config", current)
+}
+
+func mergeAdditionalModelRequestFields(existing, incoming *schemas.OrderedMap) *schemas.OrderedMap {
+	if existing == nil {
+		if incoming == nil {
+			return nil
+		}
+		return incoming.Clone()
+	}
+	if incoming == nil {
+		return existing
+	}
+
+	merged := existing.Clone()
+	incoming.Range(func(key string, value interface{}) bool {
+		if key == "output_config" {
+			current := schemas.NewOrderedMap()
+			if existingValue, ok := merged.Get(key); ok {
+				if om, ok := toOrderedMap(existingValue); ok && om != nil {
+					current = om
+				}
+			}
+			if incomingMap, ok := toOrderedMap(value); ok && incomingMap != nil {
+				mergeOrderedMapInto(current, incomingMap)
+				merged.Set(key, current)
+			} else {
+				merged.Set(key, value)
+			}
+			return true
+		}
+		merged.Set(key, value)
+		return true
+	})
+	return merged
+}
+
+func toOrderedMap(v any) (*schemas.OrderedMap, bool) {
+	switch m := v.(type) {
+	case *schemas.OrderedMap:
+		if m == nil {
+			return nil, false
+		}
+		return m.Clone(), true
+	case schemas.OrderedMap:
+		return m.Clone(), true
+	case map[string]interface{}:
+		// Fallback for callers that still provide a plain map. Order cannot be
+		// reconstructed here, but keeping this path preserves compatibility.
+		return schemas.OrderedMapFromMap(m), true
+	default:
+		return nil, false
+	}
+}
+
+// mergeOrderedMapInto deep-merges src into dst. Nested OrderedMap values are
+// merged recursively; non-map values from src overwrite dst. Existing key order
+// is preserved and newly introduced keys are appended in source order.
+func mergeOrderedMapInto(dst, src *schemas.OrderedMap) {
+	if dst == nil || src == nil {
+		return
+	}
+	src.Range(func(key string, srcVal interface{}) bool {
+		if srcMap, ok := toOrderedMap(srcVal); ok && srcMap != nil {
+			if dstVal, exists := dst.Get(key); exists {
+				if dstMap, ok := toOrderedMap(dstVal); ok && dstMap != nil {
+					mergeOrderedMapInto(dstMap, srcMap)
+					dst.Set(key, dstMap)
+					return true
+				}
+			}
+		}
+		dst.Set(key, srcVal)
+		return true
+	})
+}
+
+func newAnthropicOutputFormatOrderedMap(schemaObj any) *schemas.OrderedMap {
+	return schemas.NewOrderedMapFromPairs(
+		schemas.KV("type", "json_schema"),
+		schemas.KV("schema", schemaObj),
+	)
+}
+
 // ensureChatToolConfigForConversation ensures toolConfig is present when tool content exists
 func ensureChatToolConfigForConversation(bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) {
 	if bedrockReq.ToolConfig != nil {
@@ -825,44 +969,70 @@ func convertImageToBedrockSource(imageURL string) (*BedrockImageSource, error) {
 // convertResponseFormatToTool converts a response_format parameter to a Bedrock tool
 // Returns nil if no response_format is present or if it's not a json_schema type
 // Ref: https://aws.amazon.com/blogs/machine-learning/structured-data-response-with-amazon-bedrock-prompt-engineering-and-tool-use/
-func convertResponseFormatToTool(ctx *schemas.BifrostContext, params *schemas.ChatParameters) *BedrockTool {
+func convertResponseFormatToTool(
+	ctx *schemas.BifrostContext,
+	model string,
+	params *schemas.ChatParameters,
+) (*BedrockTool, any) {
 	if params == nil || params.ResponseFormat == nil {
-		return nil
+		return nil, nil
 	}
 
-	// ResponseFormat is stored as interface{}, need to parse it
-	responseFormatMap, ok := (*params.ResponseFormat).(map[string]interface{})
-	if !ok {
-		return nil
+	responseFormatMap, ok := schemas.SafeExtractOrderedMap(*params.ResponseFormat)
+	if !ok || responseFormatMap == nil {
+		return nil, nil
 	}
 
 	// Check if type is "json_schema"
-	formatType, ok := responseFormatMap["type"].(string)
+	formatTypeRaw, ok := responseFormatMap.Get("type")
+	if !ok {
+		return nil, nil
+	}
+	formatType, ok := schemas.SafeExtractString(formatTypeRaw)
 	if !ok || formatType != "json_schema" {
-		return nil
+		return nil, nil
 	}
 
 	// Extract json_schema object
-	jsonSchemaObj, ok := responseFormatMap["json_schema"].(map[string]interface{})
+	jsonSchemaRaw, ok := responseFormatMap.Get("json_schema")
 	if !ok {
-		return nil
+		return nil, nil
 	}
-
-	// Extract name and schema
-	toolName, ok := jsonSchemaObj["name"].(string)
-	if !ok || toolName == "" {
-		toolName = "json_response"
+	jsonSchemaObj, ok := schemas.SafeExtractOrderedMap(jsonSchemaRaw)
+	if !ok || jsonSchemaObj == nil {
+		return nil, nil
 	}
 
-	schemaObj, ok := jsonSchemaObj["schema"].(map[string]interface{})
+	schemaObj, ok := jsonSchemaObj.Get("schema")
 	if !ok {
-		return nil
+		return nil, nil
+	}
+
+	// Anthropic Bedrock supports native output_config.format. Keep this provider-specific
+	// conversion encapsulated here, and let the caller simply apply the returned values.
+	if schemas.IsAnthropicModel(model) {
+		return nil, newAnthropicOutputFormatOrderedMap(schemaObj)
+	}
+
+	// Extract name and schema
+	toolNameRaw, hasName := jsonSchemaObj.Get("name")
+	toolName, ok := schemas.SafeExtractString(toolNameRaw)
+	if !hasName || !ok || toolName == "" {
+		toolName = "json_response"
 	}
 
 	// Extract description from schema if available
 	description := "Returns structured JSON output"
-	if desc, ok := schemaObj["description"].(string); ok && desc != "" {
-		description = desc
+	if schemaMap, ok := schemas.SafeExtractOrderedMap(schemaObj); ok && schemaMap != nil {
+		if descRaw, hasDesc := schemaMap.Get("description"); hasDesc {
+			if desc, ok := schemas.SafeExtractString(descRaw); ok && desc != "" {
+				description = desc
+			}
+		}
+	} else if schemaMap, ok := schemaObj.(map[string]interface{}); ok {
+		if desc, ok := schemaMap["description"].(string); ok && desc != "" {
+			description = desc
+		}
 	}
 
 	// set bifrost context key structured output tool name
@@ -872,7 +1042,7 @@ func convertResponseFormatToTool(ctx *schemas.BifrostContext, params *schemas.Ch
 	// Create the Bedrock tool
 	schemaObjBytes, err := providerUtils.MarshalSorted(schemaObj)
 	if err != nil {
-		return nil
+		return nil, nil
 	}
 	return &BedrockTool{
 		ToolSpec: &BedrockToolSpec{
@@ -882,18 +1052,19 @@ func convertResponseFormatToTool(ctx *schemas.BifrostContext, params *schemas.Ch
 				JSON: json.RawMessage(schemaObjBytes),
 			},
 		},
-	}
+	}, nil
 }
 
-// convertTextFormatToTool converts a text config to a Bedrock tool for structured outpute
+// convertTextFormatToTool converts a Responses text.format config to either a
+// synthetic Bedrock tool or an Anthropic-native output_config.format value.
+func convertTextFormatToTool(ctx *schemas.BifrostContext, model string, textConfig *schemas.ResponsesTextConfig) (*BedrockTool, any) {
 	if textConfig == nil || textConfig.Format == nil {
-		return nil
+		return nil, nil
 	}
 
 	format := textConfig.Format
 	if format.Type != "json_schema" {
-		return nil
+		return nil, nil
 	}
 
 	toolName := "json_response"
@@ -902,23 +1073,24 @@ func convertTextFormatToTool(ctx *schemas.BifrostContext, textConfig *schemas.Re
 	}
 
 	description := "Returns structured JSON output"
+	if format.JSONSchema == nil || format.JSONSchema.Schema == nil {
+		return nil, nil // Schema is required for structured output
+	}
 	if format.JSONSchema.Description != nil {
 		description = *format.JSONSchema.Description
 	}
+	schemaObj := *format.JSONSchema.Schema
+
+	if schemas.IsAnthropicModel(model) {
+		return nil, newAnthropicOutputFormatOrderedMap(schemaObj)
+	}
 
 	toolName = fmt.Sprintf("bf_so_%s", toolName)
 	ctx.SetValue(schemas.BifrostContextKeyStructuredOutputToolName, toolName)
 
-	var schemaObj any
-	if format.JSONSchema != nil {
-		schemaObj = *format.JSONSchema
-	} else {
-		return nil // Schema is required for Bedrock tooling
-	}
-
 	schemaObjBytes2, err := providerUtils.MarshalSorted(schemaObj)
 	if err != nil {
-		return nil
+		return nil, nil
 	}
 	return &BedrockTool{
 		ToolSpec: &BedrockToolSpec{
@@ -928,7 +1100,7 @@ func convertTextFormatToTool(ctx *schemas.BifrostContext, textConfig *schemas.Re
 				JSON: json.RawMessage(schemaObjBytes2),
 			},
 		},
-	}
+	}, nil
 }
 
 // convertInferenceConfig converts Bifrost parameters to Bedrock inference config
@@ -953,14 +1125,225 @@ func convertInferenceConfig(params *schemas.ChatParameters) *BedrockInferenceCon
 	return &config
 }
 
-// convertToolConfig converts Bifrost tools to Bedrock tool config
+// collectBedrockServerTools partitions kept tools into the function/custom
+// set (which convertToolConfig materializes into Converse's toolConfig.tools)
+// and the kept-server-tool set (which cannot be expressed via Converse's
+// typed toolSpec slot and must be tunneled via additionalModelRequestFields).
+//
+// Returns:
+// - serverTools: each ChatTool serialized to its Anthropic-native JSON shape
+//   (e.g. `{"type":"computer_20251124","name":"computer","display_width_px":1280}`)
+//   ready to drop into additionalModelRequestFields.tools. Per the comment on
+//   ChatTool in core/schemas/chatcompletions.go:340-351, the default marshaler
+//   produces this shape directly — no custom codec needed.
+// - betaHeaders: anthropic-beta header values derived from the server tool
+//   Types, filtered through FilterBetaHeadersForProvider(schemas.Bedrock) so
+//   only Bedrock-approved headers survive. Only high-confidence mappings are
+//   derived here (computer_* and memory_*); callers relying on other betas
+//   (e.g. text_editor-specific headers) should continue supplying them via
+//   extra-headers / ctx — they flow through bedrock.go's existing
+//   anthropic-beta HTTP header path.
+//
+// Unsupported server tools (e.g. web_search on Bedrock) are dropped upstream
+// by ValidateChatToolsForProvider, so they never reach this helper.
+func collectBedrockServerTools(params *schemas.ChatParameters) (serverTools []json.RawMessage, betaHeaders []string) {
+	if params == nil || len(params.Tools) == 0 {
+		return nil, nil
+	}
+	filtered, _ := anthropic.ValidateChatToolsForProvider(params.Tools, schemas.Bedrock)
+	return collectBedrockServerToolsFromFiltered(filtered)
+}
+
+// collectBedrockServerToolsFromFiltered is the inner variant that accepts a
+// pre-filtered tool set (already run through ValidateChatToolsForProvider).
+// convertChatParameters filters once and passes the result to both this helper
+// and convertToolConfigFromFiltered to avoid filtering twice per request.
+func collectBedrockServerToolsFromFiltered(filtered []schemas.ChatTool) (serverTools []json.RawMessage, betaHeaders []string) {
+	if len(filtered) == 0 {
+		return nil, nil
+	}
+	seenBeta := make(map[string]struct{})
+	for _, tool := range filtered {
+		if tool.Function != nil || tool.Custom != nil {
+			continue
+		}
+		bytes, err := providerUtils.MarshalSorted(tool)
+		if err != nil {
+			continue
+		}
+		serverTools = append(serverTools, json.RawMessage(bytes))
+		for _, h := range deriveBedrockBetaHeadersForToolType(string(tool.Type)) {
+			if _, ok := seenBeta[h]; ok {
+				continue
+			}
+			seenBeta[h] = struct{}{}
+			betaHeaders = append(betaHeaders, h)
+		}
+	}
+	if len(betaHeaders) > 0 {
+		// Gate through the Bedrock-approved beta-header list.
+		betaHeaders = anthropic.FilterBetaHeadersForProvider(betaHeaders, schemas.Bedrock)
+	}
+	return serverTools, betaHeaders
+}
+
+// buildBedrockServerToolChoice emits an Anthropic-native tool_choice value
+// for tunneling through additionalModelRequestFields.tool_choice ONLY when
+// Converse's typed toolConfig.toolChoice path cannot express the caller's
+// intent:
+//
+// - Named pin of a kept server tool: convertToolConfig builds toolConfig.tools
+//   from function/custom tools only, and its reconciliation (around line
+//   1274) drops any named pin that doesn't match an entry in that slice.
+//   Server-tool names never appear there, so a legitimate pin like
+//   tool_choice={type:"function", function:{name:"computer"}} is silently
+//   dropped. We tunnel {"type":"tool","name":"computer"} instead so the
+//   forced-tool contract reaches Anthropic via Bedrock's merge.
+// - any/required with only server tools: convertToolConfig returns nil
+//   entirely (empty-slice guard since bedrockTools is empty), so the typed
+//   "any" contract is lost. We tunnel {"type":"any"} to preserve it.
+//
+// Returns (nil, false) when the typed Converse path is adequate (auto/none,
+// function-tool pin, any with function tools present, or a pin whose name
+// doesn't match any kept server tool).
+//
+// Anthropic tool_choice shape ref: platform.claude.com/docs/en/docs/agents-and-tools/tool-use/define-tools
+// ("Controlling Claude's output / Forcing tool use" — four options:
+// auto, any, tool, none; forced tool shape is {"type":"tool","name":"..."}).
func buildBedrockServerToolChoice(params *schemas.ChatParameters, filtered []schemas.ChatTool) (json.RawMessage, bool) {
	if params == nil || params.ToolChoice == nil {
		return nil, false
	}

	// Resolve effective type and optional pinned name from either the string
	// or struct representation of ChatToolChoice.
	var (
		choiceType schemas.ChatToolChoiceType
		pinnedName string
	)
	if params.ToolChoice.ChatToolChoiceStr != nil {
		choiceType = schemas.ChatToolChoiceType(*params.ToolChoice.ChatToolChoiceStr)
	} else if params.ToolChoice.ChatToolChoiceStruct != nil {
		s := params.ToolChoice.ChatToolChoiceStruct
		choiceType = s.Type
		if s.Function != nil {
			pinnedName = s.Function.Name
		} else if s.Custom != nil {
			pinnedName = s.Custom.Name
		}
	} else {
		return nil, false
	}

	// Partition kept tools: server-tool name set, plus whether any
	// function/custom tool is present.
	serverToolNames := make(map[string]struct{})
	hasFunctionOrCustom := false
	for _, tool := range filtered {
		if tool.Function != nil || tool.Custom != nil {
			hasFunctionOrCustom = true
			continue
		}
		if tool.Name != "" {
			serverToolNames[tool.Name] = struct{}{}
		}
	}

	switch choiceType {
	case schemas.ChatToolChoiceTypeFunction, schemas.ChatToolChoiceTypeCustom,
		schemas.ChatToolChoiceType("tool"):
		// Only tunnel when the pinned name matches a kept server tool.
		// Function/custom pins stay on the typed Converse path.
+		if pinnedName == "" {
+			return nil, false
+		}
+		if _, ok := serverToolNames[pinnedName]; !ok {
+			return nil, false
+		}
+		bytes, err := providerUtils.MarshalSorted(map[string]any{
+			"type": "tool",
+			"name": pinnedName,
+		})
+		if err != nil {
+			return nil, false
+		}
+		return json.RawMessage(bytes), true
+
+	case schemas.ChatToolChoiceTypeAny, schemas.ChatToolChoiceTypeRequired:
+		// When function/custom tools are present, Converse's typed
+		// toolChoice.any handles the any contract — don't double-emit.
+		if hasFunctionOrCustom || len(serverToolNames) == 0 {
+			return nil, false
+		}
+		bytes, err := providerUtils.MarshalSorted(map[string]any{"type": "any"})
+		if err != nil {
+			return nil, false
+		}
+		return json.RawMessage(bytes), true

+	default:
+		// auto, none, allowed_tools, empty, unknown — no tunneling.
+		return nil, false
+	}
+}
+
+// deriveBedrockBetaHeadersForToolType maps an Anthropic server-tool Type string
+// to the anthropic-beta header(s) Bedrock requires for the feature to activate.
+// Only high-confidence mappings are encoded here — both are anchored in
+// core/providers/anthropic/types.go (cite: B-header comments around lines 178-183).
+// Unknown prefixes return nil; callers can still inject betas via extra-headers.
+func deriveBedrockBetaHeadersForToolType(toolType string) []string {
+	switch {
+	case strings.HasPrefix(toolType, "computer_"):
+		// computer_YYYYMMDD → computer-use-YYYY-MM-DD (Bedrock B-header).
+		rest := strings.TrimPrefix(toolType, "computer_")
+		if len(rest) == 8 {
+			return []string{"computer-use-" + rest[0:4] + "-" + rest[4:6] + "-" + rest[6:8]}
+		}
+		return nil
+	case strings.HasPrefix(toolType, "memory_"):
+		// Memory activates via the context-management bundle on Bedrock
+		// (see anthropic/types.go:179 — "context-management-2025-06-27 per
+		// B-header (bundles memory)").
+		return []string{"context-management-2025-06-27"}
+	}
+	return nil
+}
+
+// convertToolConfig converts Bifrost tools to Bedrock tool config.
+//
+// Responsibilities (split from collectBedrockServerTools):
+// - Filters server tools the target provider doesn't support via
+//   ValidateChatToolsForProvider (e.g. web_search on Bedrock per cited
+//   docs — AWS user guide beta-header list, Anthropic overview feature
+//   table). Silently stripped.
+// - Materializes function/custom tools into Converse's typed toolConfig.tools.
+//   Kept server tools (bash_*, computer_*, memory_*, text_editor_*,
+//   tool_search_tool_*) are NOT emitted here — they are handled separately
+//   by collectBedrockServerTools → additionalModelRequestFields.tools, since
+//   Converse's toolSpec slot has no shape for them.
+// - Returns nil instead of an empty-slice ToolConfig, since Bedrock's
+//   Converse API rejects `"toolConfig": {"tools": []}` with a 400.
 func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToolConfig {
-	if len(params.Tools) == 0 {
+	if params == nil || len(params.Tools) == 0 {
+		return nil
+	}
+	// Strip unsupported server tools before the conversion loop.
+	filtered, _ := anthropic.ValidateChatToolsForProvider(params.Tools, schemas.Bedrock)
+	return convertToolConfigFromFiltered(model, params, filtered)
+}
+
+// convertToolConfigFromFiltered is the inner variant that accepts a
+// pre-filtered tool set. convertChatParameters uses this to avoid filtering
+// twice (once here, once in collectBedrockServerTools). The public
+// convertToolConfig entry point is a thin wrapper preserved for tests.
+func convertToolConfigFromFiltered(model string, params *schemas.ChatParameters, filtered []schemas.ChatTool) *BedrockToolConfig {
+	if params == nil {
 		return nil
 	}
 
 	var bedrockTools []BedrockTool
-	for _, tool := range params.Tools {
+	for _, tool := range filtered {
 		if tool.Function != nil {
 			// Serialize the parameters (or a default empty schema) to json.RawMessage
 			var schemaObjectBytes []byte
@@ -986,7 +1369,7 @@ func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToo
 			bedrockTool := BedrockTool{
 				ToolSpec: &BedrockToolSpec{
 					Name:        tool.Function.Name,
-					Description: schemas.Ptr(description),
+					Description: new(description),
 					InputSchema: BedrockToolInputSchema{
 						JSON: json.RawMessage(schemaObjectBytes),
 					},
@@ -1004,6 +1387,15 @@ func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToo
 		}
 	}
 
+	// Empty-guard: Bedrock's Converse API rejects {"toolConfig": {"tools": []}}
+	// with a 400 "The provided request is not valid". If every incoming tool
+	// was filtered out above (e.g. only server tools the target provider
+	// doesn't support), omit ToolConfig entirely so the request is valid and
+	// the model simply answers without tool access.
+	if len(bedrockTools) == 0 {
+		return nil
+	}
+
 	toolConfig := &BedrockToolConfig{
 		Tools: bedrockTools,
 	}
@@ -1012,7 +1404,28 @@ func convertToolConfig(model string, params *schemas.ChatParameters) *BedrockToo
 	if params.ToolChoice != nil {
 		toolChoice := convertToolChoice(*params.ToolChoice)
 		if toolChoice != nil {
-			toolConfig.ToolChoice = toolChoice
+			// Reconcile: if the choice forces a specific tool by name,
+			// verify that name still exists in the filtered tool set.
+			// Without this, a caller that pinned a server tool we just
+			// stripped (e.g. web_search on Bedrock) would ship a
+			// toolChoice.tool.name ∉ tools, and Bedrock's Converse API
+			// rejects that with a 400 ValidationException — defeating
+			// the silent-strip contract.
+			if toolChoice.Tool != nil && toolChoice.Tool.Name != "" {
+				found := false
+				for _, bt := range bedrockTools {
+					if bt.ToolSpec != nil && bt.ToolSpec.Name == toolChoice.Tool.Name {
+						found = true
+						break
+					}
+				}
+				if !found {
+					toolChoice = nil
+				}
+			}
+			if toolChoice != nil {
+				toolConfig.ToolChoice = toolChoice
+			}
 		}
 	}
diff --git a/core/providers/cerebras/cerebras.go b/core/providers/cerebras/cerebras.go
index 665c2e81bc..c32dcd7374 100644
--- a/core/providers/cerebras/cerebras.go
+++ b/core/providers/cerebras/cerebras.go
@@ -178,9 +178,6 @@ func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key sch
 	}
 
 	response := chatResponse.ToBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
 
 	return response, nil
 }
diff --git a/core/providers/cohere/chat.go b/core/providers/cohere/chat.go
index 1623d23737..fc19d18919 100644
--- a/core/providers/cohere/chat.go
+++ b/core/providers/cohere/chat.go
@@ -372,8 +372,6 @@ func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas
 		},
 		Created: int(time.Now().Unix()),
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.ChatCompletionRequest,
-			Provider:    schemas.Cohere,
 		},
 	}
diff --git a/core/providers/cohere/cohere.go b/core/providers/cohere/cohere.go
index 4386a55d11..1e5d50e087 100644
--- a/core/providers/cohere/cohere.go
+++ b/core/providers/cohere/cohere.go
@@ -155,7 +155,7 @@ func (provider *CohereProvider) buildRequestURL(ctx *schemas.BifrostContext, def
 // completeRequest sends a request to Cohere's API and handles the response.
 // It constructs the API URL, sets up authentication, and processes the response.
 // Returns the response body or an error if the request fails.
-func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string, meta *providerUtils.RequestMetadata) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
+func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jsonData []byte, url string, key string) ([]byte, time.Duration, map[string]string, *schemas.BifrostError) {
 	// Create the request with the JSON body
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -199,10 +199,10 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, latency, providerResponseHeaders, parseCohereError(resp, meta)
+		return nil, latency, providerResponseHeaders, parseCohereError(resp)
 	}
 
-	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.GetProviderKey(), provider.logger)
+	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, latency, providerResponseHeaders, decodeErr
 	}
@@ -217,8 +217,6 @@ func (provider *CohereProvider) completeRequest(ctx *schemas.BifrostContext, jso
 
 // listModelsByKey performs a list models request for a single key.
 // Returns the response and latency, or an error if the request fails.
func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -234,7 +232,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Parse and add query parameters u, err := url.Parse(baseURL) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to parse request URL", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to parse request url", err) } q := u.Query() @@ -269,15 +267,12 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.ListModelsRequest, - }) + return nil, parseCohereError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse Cohere list models response @@ -288,7 +283,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key } // Convert Cohere v2 response to Bifrost response - response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered) + response := cohereResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -352,17 +347,12 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key request, func() 
(providerUtils.RequestBodyWithExtraParams, error) { return ToCohereChatCompletionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ChatCompletionRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -375,9 +365,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -396,9 +383,6 @@ func (provider *CohereProvider) ChatCompletion(ctx *schemas.BifrostContext, key bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -424,7 +408,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, err } - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( 
ctx, request, @@ -435,8 +418,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } reqBody.Stream = schemas.Ptr(true) return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -486,9 +468,9 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -497,11 +479,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — 
pipe raw upstream SSE to client @@ -518,11 +496,12 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ChatCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -560,7 +539,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -582,11 +561,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx *schemas.BifrostContext response, bifrostErr, isLastChunk := event.ToBifrostChatCompletionStream() if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -594,11 +568,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx 
*schemas.BifrostContext if response != nil { response.ID = responseID response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() @@ -638,18 +609,13 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereResponsesRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Convert to Cohere v2 request - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/chat", schemas.ResponsesRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -662,9 +628,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -685,9 +648,6 @@ func (provider *CohereProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse.Model = request.Model // Set ExtraFields - 
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -711,7 +671,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, err } - providerName := provider.GetProviderKey() // Convert to Cohere v2 request and add streaming jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -725,8 +684,7 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos reqBody.Stream = schemas.Ptr(true) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -774,9 +732,9 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -785,11 +743,7 @@ func (provider 
*CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseCohereError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseCohereError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -806,11 +760,12 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -852,8 +807,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - provider.logger.Warn("Error reading %s stream: %v", providerName, readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + provider.logger.Warn("Error reading stream: %v", readErr) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, 
responseChan, provider.logger) return } break @@ -873,11 +828,6 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos responses, bifrostErr, isLastChunk := event.ToBifrostResponsesStream(chunkIndex, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) break @@ -886,11 +836,8 @@ func (provider *CohereProvider) ResponsesStream(ctx *schemas.BifrostContext, pos for i, response := range responses { if response != nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() chunkIndex++ @@ -934,18 +881,13 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereEmbeddingRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Create Bifrost request for conversion - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/embed", schemas.EmbeddingRequest), key.Value.GetValue()) 
if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -958,9 +900,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -979,9 +918,6 @@ func (provider *CohereProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := response.ToBifrostEmbeddingResponse() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1010,17 +946,12 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereRerankRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } - responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue(), &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.RerankRequest, - }) + responseBody, latency, providerResponseHeaders, err := provider.completeRequest(ctx, jsonBody, provider.buildRequestURL(ctx, "/v2/rerank", schemas.RerankRequest), key.Value.GetValue()) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1033,9 +964,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. return &schemas.BifrostRerankResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.RerankRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1056,9 +984,6 @@ func (provider *CohereProvider) Rerank(ctx *schemas.BifrostContext, key schemas. 
bifrostResponse.Model = request.Model // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1216,16 +1141,12 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return nil, err } - providerName := provider.GetProviderKey() - jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToCohereCountTokensRequest(request) - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1235,11 +1156,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonBody, provider.buildRequestURL(ctx, "/v1/tokenize", schemas.CountTokensRequest), key.Value.GetValue(), - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }, ) if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) @@ -1253,9 +1169,6 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch return &schemas.BifrostCountTokensResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1277,12 +1190,9 @@ func (provider *CohereProvider) CountTokens(ctx *schemas.BifrostContext, key sch bifrostResponse := cohereResponse.ToBifrostCountTokensResponse(request.Model) if bifrostResponse == nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, 
fmt.Errorf("nil Cohere count tokens response"), providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, fmt.Errorf("nil cohere count tokens response")), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.CountTokensRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders diff --git a/core/providers/cohere/cohere_test.go b/core/providers/cohere/cohere_test.go index f2fca1028f..23c73911bc 100644 --- a/core/providers/cohere/cohere_test.go +++ b/core/providers/cohere/cohere_test.go @@ -32,27 +32,27 @@ func TestCohere(t *testing.T) { RerankModel: "rerank-v3.5", ReasoningModel: "command-a-reasoning-08-2025", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not typical for Cohere - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not typical for Cohere + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, // May not support automatic - ImageURL: false, // Supported by c4ai-aya-vision-8b model - ImageBase64: true, // Supported by c4ai-aya-vision-8b model - MultipleImages: false, // Supported by c4ai-aya-vision-8b model - FileBase64: false, // Not supported - FileURL: false, // Not supported - CompleteEnd2End: false, - Embedding: true, - Rerank: true, - Reasoning: true, - ListModels: true, - CountTokens: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, // May not support automatic + ImageURL: false, 
// Supported by c4ai-aya-vision-8b model + ImageBase64: true, // Supported by c4ai-aya-vision-8b model + MultipleImages: false, // Supported by c4ai-aya-vision-8b model + FileBase64: false, // Not supported + FileURL: false, // Not supported + CompleteEnd2End: false, + Embedding: true, + Rerank: true, + Reasoning: true, + ListModels: true, + CountTokens: true, }, } diff --git a/core/providers/cohere/errors.go b/core/providers/cohere/errors.go index e9183b1b34..e444d86650 100644 --- a/core/providers/cohere/errors.go +++ b/core/providers/cohere/errors.go @@ -6,7 +6,7 @@ import ( "github.com/valyala/fasthttp" ) -func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp CohereError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) bifrostErr.Type = &errorResp.Type @@ -17,10 +17,5 @@ func parseCohereError(resp *fasthttp.Response, meta *providerUtils.RequestMetada if errorResp.Code != nil { bifrostErr.Error.Code = errorResp.Code } - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } return bifrostErr } diff --git a/core/providers/cohere/models.go b/core/providers/cohere/models.go index 3df2aab89a..3b285f97b6 100644 --- a/core/providers/cohere/models.go +++ b/core/providers/cohere/models.go @@ -2,8 +2,9 @@ package cohere import ( "encoding/json" - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -44,7 +45,7 @@ type CohereRerankMeta struct { Tokens *CohereTokenUsage `json:"tokens,omitempty"` } -func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func 
(response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -53,37 +54,39 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - includedModels := make(map[string]bool) - for _, model := range response.Models { - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.Name) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, model.Name) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.Name, - Name: schemas.Ptr(model.Name), - ContextLength: schemas.Ptr(int(model.ContextLength)), - SupportedMethods: model.Endpoints, - }) - includedModels[model.Name] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Models { + // Cohere uses model.Name as the model identifier + for _, result := range pipeline.FilterModel(model.Name) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + ContextLength: schemas.Ptr(int(model.ContextLength)), + SupportedMethods: model.Endpoints, } - if !includedModels[allowedModel] { - bifrostResponse.Data = 
append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/cohere/types.go b/core/providers/cohere/types.go index a4d78f9a48..8e5aa31402 100644 --- a/core/providers/cohere/types.go +++ b/core/providers/cohere/types.go @@ -9,8 +9,11 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinimumReasoningMaxTokens = 1 -const DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +const ( + MinimumReasoningMaxTokens = 1 + DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default +) + // Limits for tokenize input api call https://docs.cohere.com/reference/tokenize#request const ( cohereTokenizeMinTextLength = 1 diff --git a/core/providers/elevenlabs/elevenlabs.go b/core/providers/elevenlabs/elevenlabs.go index 8b144b63b0..bcd3e5cfc7 100644 --- a/core/providers/elevenlabs/elevenlabs.go +++ b/core/providers/elevenlabs/elevenlabs.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "io" "mime/multipart" "net/http" @@ -74,8 +73,6 @@ func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider { // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. 
 func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -103,10 +100,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext,
 	// Extract and set provider response headers so they're available on error paths
 	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.ListModelsRequest,
-		})
+		return nil, parseElevenlabsError(resp)
 	}

 	var elevenlabsResponse ElevenlabsListModelsResponse
@@ -115,7 +109,7 @@
 		return nil, bifrostErr
 	}

-	response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)
+	response := elevenlabsResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)

@@ -188,8 +182,6 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 		return nil, err
 	}

-	providerName := provider.GetProviderKey()
-
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -211,7 +203,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 			endpoint = "/v1/text-to-speech/" + voice
 		}
 	} else {
-		return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil)
 	}

 	requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request)
@@ -228,8 +220,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToElevenlabsSpeechRequest(request), nil
-		},
-		providerName)
+		})

 	if bifrostErr != nil {
 		return nil, bifrostErr
@@ -250,26 +241,18 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.SpeechRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}

 	// Get the response body
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}

 	// Create response based on whether timestamps were requested
 	bifrostResponse := &schemas.BifrostSpeechResponse{
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:             schemas.SpeechRequest,
-			Provider:                providerName,
-			ModelRequested:          request.Model,
 			Latency:                 latency.Milliseconds(),
 			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 		},
@@ -282,7 +265,7 @@ func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key sche
 	if withTimestampsRequest {
 		var timestampResponse ElevenlabsSpeechWithTimestampsResponse
 		if err := sonic.Unmarshal(body, &timestampResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err)
 		}

 		bifrostResponse.AudioBase64 = &timestampResponse.AudioBase64
@@ -326,15 +309,12 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 		return nil, err
 	}

-	providerName := provider.GetProviderKey()
-
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToElevenlabsSpeechRequest(request), nil
-		},
-		providerName)
+		})

 	if bifrostErr != nil {
 		return nil, bifrostErr
@@ -350,7 +330,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

 	if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}

 	req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request))
@@ -381,9 +361,9 @@
 		}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 }

 // Extract provider response headers before status check so error responses also forward them
@@ -392,11 +372,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.SpeechStreamRequest,
-		}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}

 	// Create response channel
@@ -407,9 +383,9 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 	go func() {
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -426,6 +402,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 		// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
 		stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
 		defer stopCancellation()
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)

 		// read binary audio chunks from the stream
 		// 4KB buffer for reading chunks
@@ -450,7 +427,7 @@ func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, po
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", err)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
 				return
 			}
@@ -463,11 +440,8 @@
 				Type:  schemas.SpeechStreamResponseTypeDelta,
 				Audio: audioChunk,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.SpeechStreamRequest,
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					ChunkIndex:     chunkIndex,
-					Latency:        time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency:    time.Since(lastChunkTime).Milliseconds(),
 				},
 			}

@@ -486,11 +460,8 @@
 			Type:  schemas.SpeechStreamResponseTypeDone,
 			Audio: []byte{},
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.SpeechStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex + 1,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex + 1,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}

@@ -511,32 +482,30 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 		return nil, err
 	}

-	providerName := provider.GetProviderKey()
-
 	reqBody := ToElevenlabsTranscriptionRequest(request)
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil)
 	}

 	hasFile := len(reqBody.File) > 0
 	hasURL := reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != ""
 	if hasFile && hasURL {
-		return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil)
 	}
 	if !hasFile && !hasURL {
-		return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil)
 	}

 	var body bytes.Buffer
 	writer := multipart.NewWriter(&body)
-	if bifrostErr := writeTranscriptionMultipart(writer, reqBody, providerName); bifrostErr != nil {
+	if bifrostErr := writeTranscriptionMultipart(writer, reqBody); bifrostErr != nil {
 		return nil, bifrostErr
 	}

 	contentType := writer.FormDataContentType()
 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err)
 	}

 	req := fasthttp.AcquireRequest()
@@ -567,17 +536,12 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	// Extract and set provider response headers so they're available on error paths
 	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
 	if resp.StatusCode() != fasthttp.StatusOK {
-		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, parseElevenlabsError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.TranscriptionRequest,
-		})
+		return nil, parseElevenlabsError(resp)
 	}

 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Check for empty response
@@ -593,18 +557,15 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	chunks, err := parseTranscriptionResponse(responseBody)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
 	}

 	if len(chunks) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil)
 	}

 	response := ToBifrostTranscriptionResponse(chunks)
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType:             schemas.TranscriptionRequest,
-		Provider:                providerName,
-		ModelRequested:          request.Model,
 		Latency:                 latency.Milliseconds(),
 		ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 	}
@@ -612,7 +573,7 @@ func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, k
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		var rawResponse interface{}
 		if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName)
+			rawResponse = string(responseBody)
 		}
 		response.ExtraFields.RawResponse = rawResponse
 	}
@@ -620,9 +581,9 @@
 	return response, nil
 }

-func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
+func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError {
 	if err := writer.WriteField("model_id", reqBody.ModelID); err != nil {
-		return providerUtils.NewBifrostOperationError("failed to write model_id field", err, providerName)
+		return providerUtils.NewBifrostOperationError("failed to write model_id field", err)
 	}

 	if len(reqBody.File) > 0 {
@@ -632,98 +593,98 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr
 		}
 		fileWriter, err := writer.CreateFormFile("file", filename)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("failed to create file field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to create file field", err)
 		}
 		if _, err := fileWriter.Write(reqBody.File); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write file data", err)
 		}
 	}

 	if reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" {
 		if err := writer.WriteField("cloud_storage_url", *reqBody.CloudStorageURL); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err)
 		}
 	}

 	if reqBody.LanguageCode != nil && strings.TrimSpace(*reqBody.LanguageCode) != "" {
 		if err := writer.WriteField("language_code", *reqBody.LanguageCode); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write language_code field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write language_code field", err)
 		}
 	}

 	if reqBody.TagAudioEvents != nil {
 		if err := writer.WriteField("tag_audio_events", strconv.FormatBool(*reqBody.TagAudioEvents)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err)
 		}
 	}

 	if reqBody.NumSpeakers != nil && *reqBody.NumSpeakers > 0 {
 		if err := writer.WriteField("num_speakers", strconv.Itoa(*reqBody.NumSpeakers)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err)
 		}
 	}

 	if reqBody.TimestampsGranularity != nil && *reqBody.TimestampsGranularity != "" {
 		if err := writer.WriteField("timestamps_granularity", string(*reqBody.TimestampsGranularity)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err)
 		}
 	}

 	if reqBody.Diarize != nil {
 		if err := writer.WriteField("diarize", strconv.FormatBool(*reqBody.Diarize)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write diarize field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write diarize field", err)
 		}
 	}

 	if reqBody.DiarizationThreshold != nil {
 		if err := writer.WriteField("diarization_threshold", strconv.FormatFloat(*reqBody.DiarizationThreshold, 'f', -1, 64)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err)
 		}
 	}

 	if len(reqBody.AdditionalFormats) > 0 {
 		payload, err := providerUtils.MarshalSorted(reqBody.AdditionalFormats)
 		if err != nil {
-			return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err)
 		}
 		if err := writer.WriteField("additional_formats", string(payload)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err)
 		}
 	}

 	if reqBody.FileFormat != nil && *reqBody.FileFormat != "" {
 		if err := writer.WriteField("file_format", string(*reqBody.FileFormat)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write file_format field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write file_format field", err)
 		}
 	}

 	if reqBody.Webhook != nil {
 		if err := writer.WriteField("webhook", strconv.FormatBool(*reqBody.Webhook)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write webhook field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write webhook field", err)
 		}
 	}

 	if reqBody.WebhookID != nil && strings.TrimSpace(*reqBody.WebhookID) != "" {
 		if err := writer.WriteField("webhook_id", *reqBody.WebhookID); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err)
 		}
 	}

 	if reqBody.Temperature != nil {
 		if err := writer.WriteField("temperature", strconv.FormatFloat(*reqBody.Temperature, 'f', -1, 64)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
 		}
 	}

 	if reqBody.Seed != nil {
 		if err := writer.WriteField("seed", strconv.Itoa(*reqBody.Seed)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write seed field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write seed field", err)
 		}
 	}

 	if reqBody.UseMultiChannel != nil {
 		if err := writer.WriteField("use_multi_channel", strconv.FormatBool(*reqBody.UseMultiChannel)); err != nil {
-			return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err, providerName)
+			return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err)
 		}
 	}

@@ -732,16 +693,16 @@ func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTr
 		case string:
 			if strings.TrimSpace(v) != "" {
 				if err := writer.WriteField("webhook_metadata", v); err != nil {
-					return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName)
+					return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
 				}
 			}
 		default:
 			payload, err := providerUtils.MarshalSorted(v)
 			if err != nil {
-				return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err, providerName)
+				return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err)
 			}
 			if err := writer.WriteField("webhook_metadata", string(payload)); err != nil {
-				return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err, providerName)
+				return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
 			}
 		}
 	}
diff --git a/core/providers/elevenlabs/errors.go b/core/providers/elevenlabs/errors.go
index 374e251958..f30807efd5 100644
--- a/core/providers/elevenlabs/errors.go
+++ b/core/providers/elevenlabs/errors.go
@@ -9,7 +9,7 @@ import (
 	schemas "github.com/maximhq/bifrost/core/schemas"
 )

-func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseElevenlabsError(resp *fasthttp.Response) *schemas.BifrostError {
 	var errorResp ElevenlabsError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
 	if errorResp.Detail != nil {
@@ -64,11 +64,6 @@ func parseElevenlabsError(resp *fasthttp.Response, meta *providerUtils.RequestMe
 				Message: message,
 			},
 		}
-		if meta != nil {
-			result.ExtraFields.Provider = meta.Provider
-			result.ExtraFields.ModelRequested = meta.Model
-			result.ExtraFields.RequestType = meta.RequestType
-		}
 		return result
 	}
 }
@@ -91,10 +86,5 @@
 			bifrostErr.Error.Message = message
 		}
 	}
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
-		bifrostErr.ExtraFields.ModelRequested = meta.Model
-		bifrostErr.ExtraFields.RequestType = meta.RequestType
-	}
 	return bifrostErr
 }
diff --git a/core/providers/elevenlabs/models.go b/core/providers/elevenlabs/models.go
index c211e85196..f762d97ee8 100644
--- a/core/providers/elevenlabs/models.go
+++ b/core/providers/elevenlabs/models.go
@@ -1,12 +1,13 @@
 package elevenlabs

 import (
-	"slices"
+	"strings"

+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
 )

-func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -15,35 +16,36 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
 		Data: make([]schemas.Model, 0, len(*response)),
 	}

-	includedModels := make(map[string]bool)
-
-	for _, model := range *response {
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, model.ModelID) {
-			continue
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-			ID:   string(providerKey) + "/" + model.ModelID,
-			Name: schemas.Ptr(model.Name),
-		})
-		includedModels[model.ModelID] = true
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       providerKey,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
 	}

-	// Backfill allowed models that were not in the response
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
+	included := make(map[string]bool)
+
+	for _, model := range *response {
+		for _, result := range pipeline.FilterModel(model.ModelID) {
+			entry := schemas.Model{
+				ID:   string(providerKey) + "/" + result.ResolvedID,
+				Name: schemas.Ptr(model.Name),
 			}
-			if !includedModels[allowedModel] {
-				bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
-					ID:   string(providerKey) + "/" + allowedModel,
-					Name: schemas.Ptr(allowedModel),
-				})
+			if result.AliasValue != "" {
+				entry.Alias = schemas.Ptr(result.AliasValue)
 			}
+			bifrostResponse.Data = append(bifrostResponse.Data, entry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
 	}

+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
+
 	return bifrostResponse
 }
diff --git a/core/providers/elevenlabs/realtime.go b/core/providers/elevenlabs/realtime.go
index f124f58339..a18e1cd514 100644
--- a/core/providers/elevenlabs/realtime.go
+++ b/core/providers/elevenlabs/realtime.go
@@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string]
 	return headers
 }

+// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented.
+func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool {
+	return false
+}
+
+// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs.
+func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) {
+	return "", &schemas.BifrostError{
+		IsBifrostError: true,
+		StatusCode:     schemas.Ptr(400),
+		Error:          &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"},
+	}
+}
+
+func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
+	return false
+}
+
+func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
+	return schemas.RTEventResponseDone
+}
+
+func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string {
+	return ""
+}
+
+func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string {
+	return ""
+}
+
+func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
+	return true
+}
+
+func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
+	return eventType == schemas.RTEventResponseDone
+}
+
 // ElevenLabs Conversational AI WebSocket event types
 const (
 	elConversationInitMetadata = "conversation_initiation_metadata"
@@ -50,8 +88,8 @@ const (
 	elInterruption             = "interruption"
 	elClientToolCall           = "client_tool_call"

-	elUserAudioChunk = "user_audio_chunk"
-	elPong           = "pong"
+	elUserAudioChunk   = "user_audio_chunk"
+	elPong             = "pong"
 	elClientToolResult = "client_tool_result"
 	elContextualUpdate = "contextual_update"
 )
@@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra
 		}

 	case elAgentResponse:
-		event.Type = schemas.RTEventResponseTextDone
+		event.Type = schemas.RTEventResponseDone
 		if raw.AgentResponse != nil {
 			var agentResp elevenlabsTranscriptEvent
 			if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil {
@@ -194,10 +232,6 @@
 // ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON.
 func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
-	if bifrostEvent.RawData != nil {
-		return bifrostEvent.RawData, nil
-	}
-
 	switch bifrostEvent.Type {
 	case schemas.RTEventInputAudioAppend:
 		if bifrostEvent.Delta == nil {
diff --git a/core/providers/gemini/batch.go b/core/providers/gemini/batch.go
index e3d92383f6..8f0405e524 100644
--- a/core/providers/gemini/batch.go
+++ b/core/providers/gemini/batch.go
@@ -249,8 +249,6 @@ func extractBatchIDFromName(name string) string {
 // downloadBatchResultsFile downloads and parses a batch results file from Gemini.
 // Returns the parsed result items from the JSONL file and any parse errors encountered.
 func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create request to download the file
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -287,15 +285,12 @@ func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, ke
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			RequestType: schemas.BatchResultsRequest,
-		})
+		return nil, nil, parseGeminiError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Parse JSONL content - each line is a separate JSON object
diff --git a/core/providers/gemini/errors.go b/core/providers/gemini/errors.go
index adf217a141..2d60a7bcd3 100644
--- a/core/providers/gemini/errors.go
+++ b/core/providers/gemini/errors.go
@@ -36,7 +36,7 @@ func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError {
 }

 // parseGeminiError parses Gemini error responses
-func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
+func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError {
 	// Try to parse as []GeminiGenerationError
 	var errorResps []GeminiGenerationError
 	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps)
@@ -62,11 +62,6 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 		}
 		// Set Message to trimmed concatenated message
 		bifrostErr.Error.Message = message
-		if meta != nil {
-			bifrostErr.ExtraFields.Provider = meta.Provider
-			bifrostErr.ExtraFields.ModelRequested = meta.Model
-			bifrostErr.ExtraFields.RequestType = meta.RequestType
-		}
 		return bifrostErr
 	}

@@ -80,10 +75,5 @@ func parseGeminiError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 		bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code))
 		bifrostErr.Error.Message = errorResp.Error.Message
 	}
-	if meta != nil {
-		bifrostErr.ExtraFields.Provider = meta.Provider
-		bifrostErr.ExtraFields.ModelRequested = meta.Model
-		bifrostErr.ExtraFields.RequestType = meta.RequestType
-	}
 	return bifrostErr
 }
diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go
index 1dd0842158..cad4216534 100644
--- a/core/providers/gemini/gemini.go
+++ b/core/providers/gemini/gemini.go
@@ -97,9 +97,7 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider {
 // completeRequest handles the common HTTP request pattern for Gemini API calls.
 // When large response streaming is activated (BifrostContextKeyLargeResponseMode set in ctx),
 // returns (nil, nil, latency, nil) — callers must check the context flag.
-func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string, meta *providerUtils.RequestMetadata) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
+func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, model string, key schemas.Key, jsonBody []byte, endpoint string) (*GenerateContentResponse, interface{}, time.Duration, map[string]string, *schemas.BifrostError) {
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -146,10 +144,10 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp, meta)
+		return nil, nil, latency, providerResponseHeaders, parseGeminiError(resp)
 	}

-	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, nil, latency, providerResponseHeaders, decodeErr
 	}
@@ -161,13 +159,13 @@ func (provider *GeminiProvider) completeRequest(ctx *schemas.BifrostContext, mod
 	// Parse Gemini's response
 	var geminiResponse GenerateContentResponse
 	if err := sonic.Unmarshal(body, &geminiResponse); err != nil {
-		return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+		return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 	}

 	var rawResponse interface{}
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		if err := sonic.Unmarshal(body, &rawResponse); err != nil {
-			return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+			return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 		}
 	}

@@ -208,10 +206,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    provider.GetProviderKey(),
-			RequestType: schemas.ListModelsRequest,
-		})
+		return nil, parseGeminiError(resp)
 	}

 	// Parse Gemini's response
@@ -227,7 +222,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 		}
 	}

-	response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)
+	response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)

 	response.ExtraFields.Latency = latency.Milliseconds()

@@ -282,24 +277,17 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 		return nil, err
 	}

-	providerName := provider.GetProviderKey()
-
 	jsonData, err := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiChatCompletionRequest(request)
-		},
-		provider.GetProviderKey())
+		})
 	if err != nil {
 		return nil, err
 	}

-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider:    providerName,
-		Model:       request.Model,
-		RequestType: schemas.ChatCompletionRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -312,9 +300,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 		return &schemas.BifrostChatResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ChatCompletionRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -323,9 +308,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key

 	bifrostResponse := geminiResponse.ToBifrostChatResponse()

-	bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders

@@ -363,8 +345,7 @@ func (provider *GeminiProvider) ChatCompletionStream(ctx *schemas.BifrostContext
 			return nil, fmt.Errorf("chat completion request is not provided or could not be converted to gemini format")
 		}
 		return reqBody, nil
-		},
-		provider.GetProviderKey())
+		})
 	if err != nil {
 		return nil, err
 	}
@@ -450,9 +431,9 @@ func HandleGeminiChatCompletionStream(
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 }

 // Extract provider response headers before status check so error responses also forward them
@@ -462,11 +443,7 @@ func HandleGeminiChatCompletionStream(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
 		respBody := append([]byte(nil), resp.Body()...)
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       model,
-			RequestType: schemas.ChatCompletionStreamRequest,
-		}), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, respBody, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -481,11 +458,12 @@ func HandleGeminiChatCompletionStream(

 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ChatCompletionStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -495,7 +473,6 @@ func HandleGeminiChatCompletionStream(
 			bifrostErr := providerUtils.NewBifrostOperationError(
 				"Provider returned an empty response",
 				fmt.Errorf("provider
returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -557,7 +534,7 @@ func HandleGeminiChatCompletionStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ChatCompletionStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } // Process chunk using shared function @@ -572,11 +549,6 @@ func HandleGeminiChatCompletionStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -597,11 +569,6 @@ func HandleGeminiChatCompletionStream( // Convert to Bifrost stream response response, bifrostErr, isLastChunk := geminiResponse.ToBifrostChatCompletionStream(streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -613,11 +580,8 @@ func HandleGeminiChatCompletionStream( response.Model = modelName } 
response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -692,8 +656,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -705,11 +668,7 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem } // Use struct directly for JSON marshaling - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -722,9 +681,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem return &schemas.BifrostResponsesResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -735,9 +691,6 @@ func (provider *GeminiProvider) Responses(ctx *schemas.BifrostContext, key schem bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() // Set ExtraFields - 
bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -767,13 +720,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( bodyReader io.Reader, // Optional: for large payload request streaming (pass nil for normal path) bodySize int, // Required if bodyReader is non-nil ) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - meta := &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - } - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -807,14 +753,14 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Handle error response — materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - bifrostErr := parseGeminiError(resp, meta) + bifrostErr := parseGeminiError(resp) wait() fasthttp.ReleaseResponse(resp) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Delegate large response detection + normal buffered path to shared utility - responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLarge, respErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if respErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -830,9 +776,6 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( Model: request.Model, Usage: usage, } - bifrostResponse.ExtraFields.Provider = providerName - 
bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // resp owned by reader in context — don't release wait() @@ -844,12 +787,9 @@ func (provider *GeminiProvider) responsesWithLargeResponseDetection( // Normal parse-and-convert path var geminiResponse GenerateContentResponse if unmarshalErr := sonic.Unmarshal(responseBody, &geminiResponse); unmarshalErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, unmarshalErr) } bifrostResponse := geminiResponse.ToResponsesBifrostResponsesResponse() - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -901,8 +841,7 @@ func (provider *GeminiProvider) ResponsesStream(ctx *schemas.BifrostContext, pos return nil, fmt.Errorf("responses input is not provided or could not be converted to gemini format") } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -987,9 +926,9 @@ func HandleGeminiResponsesStream( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(doErr, fasthttp.ErrTimeout) || errors.Is(doErr, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, doErr), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -998,11 +937,7 @@ func HandleGeminiResponsesStream( // Check for HTTP errors — use parseGeminiError to preserve upstream error details if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.ResponsesStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1017,11 +952,12 @@ func HandleGeminiResponsesStream( // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } 
close(responseChan) }() @@ -1032,7 +968,6 @@ func HandleGeminiResponsesStream( bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", fmt.Errorf("provider returned an empty response"), - providerName, ) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError( @@ -1101,7 +1036,7 @@ func HandleGeminiResponsesStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } @@ -1117,11 +1052,6 @@ func HandleGeminiResponsesStream( Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) @@ -1139,11 +1069,6 @@ func HandleGeminiResponsesStream( // Convert to Bifrost responses stream response responses, bifrostErr := geminiResponse.ToBifrostResponsesStream(sequenceNumber, streamState) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1152,11 +1077,8 @@ func HandleGeminiResponsesStream( for i, response := range responses { if response != 
nil { response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1209,11 +1131,8 @@ func HandleGeminiResponsesStream( continue } finalResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } if postResponseConverter != nil { @@ -1258,8 +1177,7 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiEmbeddingRequest(request), nil - }, - providerName) + }) if err != nil { return nil, err } @@ -1310,17 +1228,13 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.EmbeddingRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + parsedErr := providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) wait() fasthttp.ReleaseResponse(resp) return nil, parsedErr } - body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + body, isLargeResp, decodeErr := 
providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { wait() fasthttp.ReleaseResponse(resp) @@ -1333,9 +1247,6 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem return &schemas.BifrostEmbeddingResponse{ Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.EmbeddingRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1357,12 +1268,9 @@ func (provider *GeminiProvider) Embedding(ctx *schemas.BifrostContext, key schem bifrostResponse := ToBifrostEmbeddingResponse(&geminiResponse, request.Model) if bifrostResponse == nil { return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, - fmt.Errorf("failed to convert Gemini embedding response to Bifrost format"), providerName) + fmt.Errorf("failed to convert Gemini embedding response to Bifrost format")) } - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1391,18 +1299,13 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.SpeechRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1414,9 +1317,6 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostSpeechResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.SpeechRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1428,13 +1328,10 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas. 
} response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.SpeechRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1466,16 +1363,13 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo return nil, err } - providerName := provider.GetProviderKey() - // Prepare request body using speech-specific function jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1521,9 +1415,9 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo }, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Extract provider response headers before status check so error responses also forward them @@ -1532,11 +1426,7 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.SpeechStreamRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1553,11 +1443,12 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -1597,7 +1488,7 
@@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) return } break @@ -1617,11 +1508,6 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Message: err.Error(), Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) @@ -1672,11 +1558,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDelta, Audio: audioChunk, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } lastChunkTime = time.Now() @@ -1693,11 +1576,8 @@ func (provider *GeminiProvider) SpeechStream(ctx *schemas.BifrostContext, postHo Type: schemas.SpeechStreamResponseTypeDone, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.SpeechStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex + 1, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex + 1, + Latency: time.Since(startTime).Milliseconds(), }, } response.BackfillParams(request) @@ -1725,18 +1605,13 @@ func 
(provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiTranscriptionRequest(request), nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } // Use common request function - geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{ - Provider: provider.GetProviderKey(), - Model: request.Model, - RequestType: schemas.TranscriptionRequest, - }) + geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent") if providerResponseHeaders != nil { ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders) } @@ -1748,9 +1623,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp { return &schemas.BifrostTranscriptionResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TranscriptionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -1760,9 +1632,6 @@ func (provider *GeminiProvider) Transcription(ctx *schemas.BifrostContext, key s response := geminiResponse.ToBifrostTranscriptionResponse() // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TranscriptionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1784,16 +1653,13 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext, return 
 		nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Prepare request body using transcription-specific function
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiTranscriptionRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1839,9 +1705,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 		}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 }
 
 // Extract provider response headers before status check so error responses also forward them
@@ -1850,11 +1716,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 	// Check for HTTP errors
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			Model: request.Model,
-			RequestType: schemas.TranscriptionStreamRequest,
-		}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -1871,11 +1733,12 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1915,7 +1778,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", readErr)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger)
 				return
 			}
 			break
@@ -1934,11 +1797,6 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 					Message: err.Error(),
 					Error: err,
 				},
-				ExtraFields: schemas.BifrostErrorExtraFields{
-					RequestType: schemas.TranscriptionStreamRequest,
-					Provider: providerName,
-					ModelRequested: request.Model,
-				},
 			}
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
@@ -1983,11 +1841,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 				Type: schemas.TranscriptionStreamResponseTypeDelta,
 				Delta: &deltaText, // Delta text for this chunk
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.TranscriptionStreamRequest,
-					Provider: providerName,
-					ModelRequested: request.Model,
-					ChunkIndex: chunkIndex,
-					Latency: time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex,
+					Latency: time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
 			lastChunkTime = time.Now()
@@ -2010,11 +1865,8 @@ func (provider *GeminiProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 					TotalTokens: usage.TotalTokens,
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.TranscriptionStreamRequest,
-					Provider: providerName,
-					ModelRequested: request.Model,
-					ChunkIndex: chunkIndex + 1,
-					Latency: time.Since(startTime).Milliseconds(),
+					ChunkIndex: chunkIndex + 1,
+					Latency: time.Since(startTime).Milliseconds(),
 				},
 			}
 
@@ -2047,18 +1899,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiImageGenerationRequest(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Use common request function
-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider: provider.GetProviderKey(),
-		Model: request.Model,
-		RequestType: schemas.ImageGenerationRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -2070,9 +1917,6 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider: provider.GetProviderKey(),
-				ModelRequested: request.Model,
-				RequestType: schemas.ImageGenerationRequest,
 				Latency: latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -2081,25 +1925,16 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 
 	response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse()
 	if bifrostErr != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider: provider.GetProviderKey(),
-			ModelRequested: request.Model,
-			RequestType: schemas.ImageGenerationRequest,
-		}
 		return nil, bifrostErr
 	}
 	if response == nil {
 		return nil, providerUtils.NewBifrostOperationError(
 			"failed to convert Gemini image generation response",
 			fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"),
-			provider.GetProviderKey(),
 		)
 	}
 
 	// Set ExtraFields
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -2116,16 +1951,13 @@ func (provider *GeminiProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 
 // handleImagenImageGeneration handles Imagen model requests using Vertex AI endpoint with API key auth
 func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Prepare Imagen request body
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToImagenImageGenerationRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2165,16 +1997,11 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			Model: request.Model,
-			RequestType: schemas.ImageGenerationRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse Imagen response
-	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, decodeErr
 	}
@@ -2182,10 +2009,7 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 		respOwned = false
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider: providerName,
-				ModelRequested: request.Model,
-				RequestType: schemas.ImageGenerationRequest,
-				Latency: latency.Milliseconds(),
+				Latency: latency.Milliseconds(),
 			},
 		}, nil
 	}
@@ -2197,9 +2021,6 @@ func (provider *GeminiProvider) handleImagenImageGeneration(ctx *schemas.Bifrost
 	}
 
 	// Convert to Bifrost format
 	response := imagenResponse.ToBifrostImageGenerationResponse()
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -2224,8 +2045,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Handle Imagen models using :predict endpoint
 	if schemas.IsImagenModel(request.Model) {
 		jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -2233,8 +2052,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			request,
 			func() (providerUtils.RequestBodyWithExtraParams, error) {
 				return ToImagenImageEditRequest(request), nil
-			},
-			providerName)
+			})
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
@@ -2269,15 +2087,10 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 
 		if resp.StatusCode() != fasthttp.StatusOK {
 			providerUtils.MaterializeStreamErrorBody(ctx, resp)
-			provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-			return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-				Provider: providerName,
-				Model: request.Model,
-				RequestType: schemas.ImageEditRequest,
-			}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
-		body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+		body, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 		if decodeErr != nil {
 			return nil, decodeErr
 		}
@@ -2285,10 +2098,7 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			imagenRespOwned = false
 			return &schemas.BifrostImageGenerationResponse{
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					Provider: providerName,
-					ModelRequested: request.Model,
-					RequestType: schemas.ImageEditRequest,
-					Latency: latency.Milliseconds(),
+					Latency: latency.Milliseconds(),
 				},
 			}, nil
 		}
@@ -2300,9 +2110,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		}
 
 		response := imagenResponse.ToBifrostImageGenerationResponse()
-		response.ExtraFields.Provider = providerName
-		response.ExtraFields.ModelRequested = request.Model
-		response.ExtraFields.RequestType = schemas.ImageEditRequest
 		response.ExtraFields.Latency = latency.Milliseconds()
 
 		if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -2321,18 +2128,13 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiImageEditRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Use common request function
-	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent", &providerUtils.RequestMetadata{
-		Provider: providerName,
-		Model: request.Model,
-		RequestType: schemas.ImageEditRequest,
-	})
+	geminiResponse, rawResponse, latency, providerResponseHeaders, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonData, ":generateContent")
 	if providerResponseHeaders != nil {
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
 	}
@@ -2344,9 +2146,6 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 	if isLargeResp, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool); isLargeResp {
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider: providerName,
-				ModelRequested: request.Model,
-				RequestType: schemas.ImageEditRequest,
 				Latency: latency.Milliseconds(),
 				ProviderResponseHeaders: providerResponseHeaders,
 			},
@@ -2355,25 +2154,16 @@ func (provider *GeminiProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 
 	response, bifrostErr := geminiResponse.ToBifrostImageGenerationResponse()
 	if bifrostErr != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider: providerName,
-			ModelRequested: request.Model,
-			RequestType: schemas.ImageEditRequest,
-		}
 		return nil, bifrostErr
 	}
 	if response == nil {
 		return nil, providerUtils.NewBifrostOperationError(
 			"failed to convert Gemini image edit response",
 			fmt.Errorf("ToBifrostImageGenerationResponse returned nil response"),
-			providerName,
 		)
 	}
 
 	// Set ExtraFields
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageEditRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 
@@ -2405,7 +2195,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	model := bifrostReq.Model
 
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -2414,7 +2203,6 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToGeminiVideoGenerationRequest(bifrostReq)
 		},
-		providerName,
 	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2447,17 +2235,13 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			Model: model,
-			RequestType: schemas.VideoGenerationRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// use handle provider response
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse response
@@ -2473,12 +2257,9 @@ func (provider *GeminiProvider) VideoGeneration(ctx *schemas.BifrostContext, key
 		return nil, bifrostErr
 	}
 
-	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
+	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
 
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.ModelRequested = model
-	bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest
 
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		bifrostResp.ExtraFields.RawRequest = rawRequest
@@ -2496,10 +2277,9 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	operationID := bifrostReq.ID
 
-	operationID = providerUtils.StripVideoIDProviderSuffix(operationID, providerName)
+	operationID = providerUtils.StripVideoIDProviderSuffix(operationID, provider.GetProviderKey())
 
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
@@ -2524,10 +2304,8 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.VideoRetrieveRequest,
-		}), nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		respBody := append([]byte(nil), resp.Body()...)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Parse response
@@ -2541,12 +2319,10 @@ func (provider *GeminiProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
-	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
+	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
 
 	// Add extra fields
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest
 
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		bifrostResp.ExtraFields.RawResponse = rawResponse
@@ -2560,9 +2336,8 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	if err := providerUtils.CheckOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
 	if request == nil || request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 
 	// Retrieve operation first so download behavior follows retrieve status.
 	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
@@ -2577,11 +2352,10 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
 			nil,
-			providerName,
 		)
 	}
 	if len(videoResp.Videos) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
 	}
 	var content []byte
 	contentType := "video/mp4"
@@ -2592,7 +2366,7 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 		startTime := time.Now()
 		decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err)
 		}
 		content = decoded
 		latency = time.Since(startTime)
@@ -2623,17 +2397,16 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 			return nil, providerUtils.NewBifrostOperationError(
 				fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
 				nil,
-				providerName,
 			)
 		}
 
 		body, err := providerUtils.CheckAndDecodeBody(resp)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 		}
 		contentType = string(resp.Header.ContentType())
 		content = append([]byte(nil), body...)
 	} else {
-		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
 	}
 	bifrostResp := &schemas.BifrostVideoDownloadResponse{
 		VideoID: request.ID,
@@ -2642,8 +2415,6 @@ func (provider *GeminiProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	}
 
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider = providerName
-	bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest
 
 	return bifrostResp, nil
 }
@@ -2673,18 +2444,16 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	// Validate that either InputFileID or Requests is provided, but not both
 	hasFileInput := request.InputFileID != ""
 	hasInlineRequests := len(request.Requests) > 0
 
 	if !hasFileInput && !hasInlineRequests {
-		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests must be provided", nil)
 	}
 
 	if hasFileInput && hasInlineRequests {
-		return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("cannot specify both input_file_id and requests", nil)
 	}
 
 	// Build the batch request with proper nested structure
@@ -2717,12 +2486,12 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		if rawMessages, ok := body["messages"]; ok {
 			messagesBytes, err := providerUtils.MarshalSorted(rawMessages)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err, providerName)
+				return nil, providerUtils.NewBifrostOperationError("failed to marshal messages", err)
 			}
 			var chatMessages []schemas.ChatMessage
 			err = sonic.Unmarshal(messagesBytes, &chatMessages)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err, providerName)
+				return nil, providerUtils.NewBifrostOperationError("failed to unmarshal messages", err)
 			}
 
 			contents, systemInstruction := convertBifrostMessagesToGemini(chatMessages)
@@ -2732,11 +2501,11 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 			// If no "messages" key, try direct unmarshal (already in Gemini format)
 			requestBytes, err := providerUtils.MarshalSorted(body)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err, providerName)
+				return nil, providerUtils.NewBifrostOperationError("failed to marshal gemini request", err)
 			}
 			err = sonic.Unmarshal(requestBytes, &geminiReq)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err, providerName)
+				return nil, providerUtils.NewBifrostOperationError("failed to unmarshal gemini request", err)
 			}
 		}
 
@@ -2760,7 +2529,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 
 	jsonData, err := providerUtils.MarshalSorted(batchReq)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 
 	// Create HTTP request
@@ -2798,31 +2567,27 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			Model: model,
-			RequestType: schemas.BatchCreateRequest,
-		}), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
 
 	// Parse the batch job response
 	var geminiResp GeminiBatchJobResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
 		provider.logger.Error("gemini batch create unmarshal error: " + err.Error())
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Check for metadata
 	if geminiResp.Metadata == nil {
-		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing metadata", nil)
 	}
 
 	// Check for batch stats
 	if geminiResp.Metadata.BatchStats == nil {
-		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("gemini batch response missing batch stats", nil)
 	}
 
 	// Calculate request counts based on response
 	totalRequests := geminiResp.Metadata.BatchStats.RequestCount
@@ -2865,9 +2630,7 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 			Failed: failedCount,
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchCreateRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 
@@ -2886,8 +2649,6 @@ func (provider *GeminiProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 
 // batchListByKey lists batch jobs for Gemini for a single key.
 func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, time.Duration, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -2935,26 +2696,21 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 				Data: []schemas.BifrostBatchRetrieveResponse{},
 				HasMore: false,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.BatchListRequest,
-					Provider: providerName,
-					Latency: latency.Milliseconds(),
+					Latency: latency.Milliseconds(),
 				},
 			}, latency, nil
 		}
-		return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.BatchListRequest,
-		})
+		return nil, latency, parseGeminiError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	var geminiResp GeminiBatchListResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
-		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+		return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 	}
 
 	// Convert to Bifrost format
@@ -2966,10 +2722,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 			Status: ToBifrostBatchStatus(batch.Metadata.State),
 			CreatedAt: parseGeminiTimestamp(batch.Metadata.CreateTime),
 			OperationName: &batch.Name,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider: providerName,
-			},
+			ExtraFields: schemas.BifrostResponseExtraFields{},
 		})
 	}
 
@@ -2985,9 +2738,7 @@ func (provider *GeminiProvider) batchListByKey(ctx *schemas.BifrostContext, key
 		HasMore: hasMore,
 		NextCursor: nextCursor,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}, latency, nil
 }
@@ -3001,16 +2752,14 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch list", nil)
 	}
 
 	// Initialize serial pagination helper (Gemini uses PageToken for pagination)
 	helper, err := providerUtils.NewSerialListHelper(keys, request.PageToken, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 
 	// Get current key to query
@@ -3021,10 +2770,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 			Object: "list",
 			Data: []schemas.BifrostBatchRetrieveResponse{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider: providerName,
-			},
 		}, nil
 	}
 
@@ -3056,9 +2801,7 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		Data: resp.Data,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 	if nextCursor != "" {
@@ -3070,8 +2813,6 @@ func (provider *GeminiProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 
 // batchRetrieveByKey retrieves a specific batch job for Gemini for a single key.
 func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -3104,20 +2845,17 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext,
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.BatchRetrieveRequest,
-		})
+		return nil, parseGeminiError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	var geminiResp GeminiBatchJobResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 	}
 
 	var completedCount, failedCount int
@@ -3146,9 +2884,7 @@ func (provider *GeminiProvider) batchRetrieveByKey(ctx *schemas.BifrostContext,
 			Failed: failedCount,
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchRetrieveRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}, nil
 }
@@ -3159,14 +2895,12 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
 	if len(keys) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch retrieve", nil)
 	}
 
 	// Try each key until we find the batch
@@ -3185,8 +2919,6 @@ func (provider *GeminiProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 
 // batchCancelByKey cancels a batch job for Gemini for a single key.
 func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// Create HTTP request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -3224,15 +2956,9 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke
 		if resp.StatusCode() == fasthttp.StatusNotFound || resp.StatusCode() == fasthttp.StatusMethodNotAllowed {
 			// 404 could mean batch not found or cancel not supported
 			// Return the error instead of assuming completed
-			return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-				Provider: providerName,
-				RequestType: schemas.BatchCancelRequest,
-			})
+			return nil, parseGeminiError(resp)
 		}
-		return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.BatchCancelRequest,
-		})
+		return nil, parseGeminiError(resp)
 	}
 
 	now := time.Now().Unix()
@@ -3242,9 +2968,7 @@ func (provider *GeminiProvider) batchCancelByKey(ctx *schemas.BifrostContext, ke
 		Status: schemas.BatchStatusCancelling,
 		CancellingAt: &now,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchCancelRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}, nil
 }
@@ -3256,14 +2980,12 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
 	if len(keys) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch cancel", nil)
 	}
 
 	// Try each key until cancellation succeeds
@@ -3285,8 +3007,6 @@ func (provider *GeminiProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 // batches.delete indicates the client is no longer interested in the operation result.
 // It does not cancel the operation. If the server doesn't support this method, it returns UNIMPLEMENTED.
 func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
 	defer fasthttp.ReleaseRequest(req)
@@ -3315,10 +3035,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke
 	}
 
 	if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent {
-		return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.BatchDeleteRequest,
-		})
+		return nil, parseGeminiError(resp)
 	}
 
 	return &schemas.BifrostBatchDeleteResponse{
@@ -3326,9 +3043,7 @@ func (provider *GeminiProvider) batchDeleteByKey(ctx *schemas.BifrostContext, ke
 		Object: "batch",
 		Status: schemas.BatchStatusDeleted,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchDeleteRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}, nil
 }
@@ -3341,14 +3056,12 @@ func (provider *GeminiProvider) BatchDelete(ctx *schemas.BifrostContext, keys []
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
-
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
 
 	if len(keys) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("no keys provided for batch delete", nil)
 	}
 
 	var lastError *schemas.BifrostError
@@ -3496,8 +3209,6 @@ func readNextSSEDataLine(reader *bufio.Reader, skipInlineData bool) ([]byte, err
 
 // batchResultsByKey retrieves batch results for Gemini for a single key.
 func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	// We need to get the full batch response with results, so make the API call directly
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -3531,20 +3242,17 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k
 
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{
-			Provider: providerName,
-			RequestType: schemas.BatchResultsRequest,
-		})
+		return nil, parseGeminiError(resp)
 	}
 
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	var geminiResp GeminiBatchJobResponse
 	if err := sonic.Unmarshal(body, &geminiResp); err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 	}
 
 	// Check if batch is still processing
@@ -3552,7 +3260,6 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("batch %s is still processing (state: %s), results not yet available", request.BatchID, geminiResp.Metadata.State),
 			nil,
-			providerName,
 		)
 	}
 
@@ -3640,9 +3347,7 @@ func (provider *GeminiProvider) batchResultsByKey(ctx *schemas.BifrostContext, k
 		BatchID: request.BatchID,
 		Results: results,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchResultsRequest,
-			Provider: providerName,
-			Latency: latency.Milliseconds(),
+			Latency:
latency.Milliseconds(), }, } @@ -3661,14 +3366,12 @@ func (provider *GeminiProvider) BatchResults(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.BatchID == "" { - return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for batch results", nil) } // Try each key until we get results @@ -3692,10 +3395,8 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart request @@ -3705,14 +3406,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add file metadata as JSON metadataField, err := writer.CreateFormField("metadata") if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata field", err) } metadataJSON, err := providerUtils.SetJSONField([]byte(`{}`), "file.displayName", request.Filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } if _, err := metadataField.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, 
providerUtils.NewBifrostOperationError("failed to write metadata", err) } // Add file content @@ -3722,14 +3423,14 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3760,15 +3461,12 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileUploadRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Parse response - wrapped in "file" object @@ -3776,7 +3474,7 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche File GeminiFileResponse `json:"file"` } if err := sonic.Unmarshal(body, &responseWrapper); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, 
providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } geminiResp := responseWrapper.File @@ -3812,17 +3510,13 @@ func (provider *GeminiProvider) FileUpload(ctx *schemas.BifrostContext, key sche StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } // fileListByKey lists files from Gemini for a single key. func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, time.Duration, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -3859,20 +3553,17 @@ func (provider *GeminiProvider) fileListByKey(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileListRequest, - }) + return nil, latency, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileListResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost response @@ -3881,9 +3572,7 @@ func (provider *GeminiProvider) fileListByKey(ctx 
*schemas.BifrostContext, key s Data: make([]schemas.FileObject, len(geminiResp.Files)), HasMore: geminiResp.NextPageToken != "", ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -3936,16 +3625,14 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch return nil, err } - providerName := provider.GetProviderKey() - if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file list", nil) } // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -3956,10 +3643,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3991,9 +3674,7 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: resp.Data, HasMore: hasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } if nextCursor != "" { @@ -4005,8 +3686,6 @@ func (provider *GeminiProvider) FileList(ctx *schemas.BifrostContext, keys []sch // fileRetrieveByKey retrieves file metadata from Gemini for a single key. 
func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4037,20 +3716,17 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileRetrieveRequest, - }) + return nil, parseGeminiError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var geminiResp GeminiFileResponse if err := sonic.Unmarshal(body, &geminiResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } var sizeBytes int64 @@ -4087,9 +3763,7 @@ func (provider *GeminiProvider) fileRetrieveByKey(ctx *schemas.BifrostContext, k StorageURI: geminiResp.URI, ExpiresAt: expiresAt, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4100,14 +3774,12 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is 
required", nil) } if len(keys) == 0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file retrieve", nil) } // Try each key until we find the file @@ -4127,8 +3799,6 @@ func (provider *GeminiProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // fileDeleteByKey deletes a file from Gemini for a single key. func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -4159,10 +3829,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key // Handle error response - DELETE returns 200 with empty body on success if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusNoContent { - return nil, parseGeminiError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.FileDeleteRequest, - }) + return nil, parseGeminiError(resp) } return &schemas.BifrostFileDeleteResponse{ @@ -4170,9 +3837,7 @@ func (provider *GeminiProvider) fileDeleteByKey(ctx *schemas.BifrostContext, key Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -4183,14 +3848,12 @@ func (provider *GeminiProvider) FileDelete(ctx *schemas.BifrostContext, keys []s return nil, err } - providerName := provider.GetProviderKey() - if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } if len(keys) == 
0 { - return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("no keys provided for file delete", nil) } // Try each key until deletion succeeds @@ -4216,14 +3879,11 @@ func (provider *GeminiProvider) FileContent(ctx *schemas.BifrostContext, keys [] return nil, err } - providerName := provider.GetProviderKey() - // Gemini doesn't support direct file content download // Files are referenced by their URI in requests return nil, providerUtils.NewBifrostOperationError( "Gemini Files API doesn't support direct content download. Use the file URI in your requests instead.", nil, - providerName, ) } @@ -4254,7 +3914,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return ToGeminiResponsesRequest(request) }, - provider.GetProviderKey(), ) if bifrostErr != nil { return nil, bifrostErr @@ -4266,14 +3925,13 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch jsonData, _ = providerUtils.DeleteJSONField(jsonData, "systemInstruction") } - providerName := provider.GetProviderKey() req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseRequest(req) defer fasthttp.ReleaseResponse(resp) if strings.TrimSpace(request.Model) == "" { - return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model"), providerName) + return nil, providerUtils.NewBifrostOperationError("model is required for Gemini count tokens request", fmt.Errorf("missing model")) } // Determine native model name (e.g., parse any provider prefix) @@ -4306,15 +3964,12 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch } if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp, &providerUtils.RequestMetadata{ - 
Provider: providerName, - RequestType: schemas.CountTokensRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseGeminiError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } responseBody := append([]byte(nil), body...) @@ -4334,9 +3989,6 @@ func (provider *GeminiProvider) CountTokens(ctx *schemas.BifrostContext, key sch response := geminiResponse.ToBifrostCountTokensResponse(request.Model) // Set ExtraFields - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.CountTokensRequest response.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -4439,7 +4091,7 @@ func (provider *GeminiProvider) Passthrough( headers := providerUtils.ExtractProviderResponseHeaders(resp) body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -4453,9 +4105,6 @@ func (provider *GeminiProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - 
bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4519,9 +4168,9 @@ func (provider *GeminiProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := providerUtils.ExtractProviderResponseHeaders(resp) @@ -4532,7 +4181,6 @@ func (provider *GeminiProvider) PassthroughStream( return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), ) } @@ -4544,11 +4192,7 @@ func (provider *GeminiProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -4557,11 +4201,12 @@ func (provider *GeminiProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ -4610,7 +4255,7 @@ func (provider *GeminiProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } diff --git a/core/providers/gemini/images.go b/core/providers/gemini/images.go index b390537b3e..881988392a 100644 --- a/core/providers/gemini/images.go +++ b/core/providers/gemini/images.go @@ -408,10 +408,10 @@ func ToGeminiImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size 
conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" && aspectRatio != "" { geminiReq.GenerationConfig.ImageConfig = &GeminiImageConfig{ - ImageSize: imageSize, + ImageSize: strings.ToLower(imageSize), AspectRatio: aspectRatio, } } @@ -513,9 +513,10 @@ func ToImagenImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRe // Handle size conversion if bifrostReq.Params.Size != nil && strings.ToLower(*bifrostReq.Params.Size) != "auto" { - imageSize, aspectRatio := convertSizeToImagenFormat(*bifrostReq.Params.Size) + aspectRatio, imageSize := utils.ConvertSizeToAspectRatioAndResolution(*bifrostReq.Params.Size) if imageSize != "" { - req.Parameters.SampleImageSize = &imageSize + imageSizeLower := strings.ToLower(imageSize) + req.Parameters.SampleImageSize = &imageSizeLower } if aspectRatio != "" { req.Parameters.AspectRatio = &aspectRatio @@ -638,55 +639,6 @@ func convertOutputFormatToMimeType(outputFormat string) string { } } -// convertSizeToImagenFormat converts standard size format (e.g., "1024x1024") to Imagen format -// Returns (imageSize, aspectRatio) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9" -func convertSizeToImagenFormat(size string) (string, string) { - // Parse size string (format: "WIDTHxHEIGHT") - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - // Validate width and height are positive integers - if width <= 0 || height <= 0 { - return "", "" - } - - var imageSize string - if width <= 1024 && height <= 1024 { - imageSize = "1k" - } else if width <= 2048 && height <= 2048 { - 
imageSize = "2k" - } else if width <= 4096 && height <= 4096 { - imageSize = "4k" - } - - // Calculate aspect ratio - var aspectRatio string - ratio := float64(width) / float64(height) - - // Common aspect ratios with tolerance - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return imageSize, aspectRatio -} // ToBifrostImageGenerationResponse converts an Imagen response to Bifrost format func (response *GeminiImagenResponse) ToBifrostImageGenerationResponse() *schemas.BifrostImageGenerationResponse { diff --git a/core/providers/gemini/models.go b/core/providers/gemini/models.go index 4c8f83c364..7b9f6410eb 100644 --- a/core/providers/gemini/models.go +++ b/core/providers/gemini/models.go @@ -1,9 +1,9 @@ package gemini import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) @@ -17,7 +17,7 @@ func toGeminiModelResourceName(modelID string) string { return "models/" + modelID } -func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -26,45 +26,47 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Models)), } - includedModels := make(map[string]bool) - for _, model := range 
response.Models { + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + included := make(map[string]bool) + + for _, model := range response.Models { contextLength := model.InputTokenLimit + model.OutputTokenLimit - // Remove prefix models/ from model.Name + // Gemini returns model names with a "models/" prefix — strip it before filtering + // so that allowedModels entries like "gemini-1.5-pro" match correctly. modelName := strings.TrimPrefix(model.Name, "models/") - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelName) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, modelName) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + modelName, - Name: schemas.Ptr(model.DisplayName), - Description: schemas.Ptr(model.Description), - ContextLength: schemas.Ptr(int(contextLength)), - MaxInputTokens: schemas.Ptr(model.InputTokenLimit), - MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), - SupportedMethods: model.SupportedGenerationMethods, - }) - includedModels[modelName] = true - } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + for _, result := range pipeline.FilterModel(modelName) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.DisplayName), + Description: schemas.Ptr(model.Description), + ContextLength: schemas.Ptr(int(contextLength)), + MaxInputTokens: schemas.Ptr(model.InputTokenLimit), + MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit), + SupportedMethods: 
model.SupportedGenerationMethods, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/gemini/types.go b/core/providers/gemini/types.go index 29935ca23b..a30eadcaf5 100644 --- a/core/providers/gemini/types.go +++ b/core/providers/gemini/types.go @@ -17,11 +17,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -const MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level -const DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation -const DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini -const DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini -const skipThoughtSignatureValidator = "skip_thought_signature_validator" +const ( + MinReasoningMaxTokens = 1 // Minimum max tokens for reasoning - used for estimation of effort level + DefaultCompletionMaxTokens = 8192 // Default max output tokens for Gemini - used for relative reasoning max token calculation + DefaultReasoningMinBudget = 1024 // Default minimum reasoning budget for Gemini + DynamicReasoningBudget = -1 // Special value for dynamic reasoning budget in Gemini + skipThoughtSignatureValidator = "skip_thought_signature_validator" +) type thinkingBudgetRange struct { Min int @@ -523,8 +525,7 @@ type GoogleMaps struct { } // URLContext is a tool to support URL context retrieval. 
-type URLContext struct { -} +type URLContext struct{} // ToolComputerUse is a tool to support computer use. type ToolComputerUse struct { @@ -569,8 +570,7 @@ type ExternalAPIElasticSearchParams struct { } // ExternalAPISimpleSearchParams represents the search parameters to use for SIMPLE_SEARCH spec. -type ExternalAPISimpleSearchParams struct { -} +type ExternalAPISimpleSearchParams struct{} // ExternalAPI retrieves from data source powered by external API for grounding. The external API // is not owned by Google, but needs to follow the pre-defined API spec. @@ -728,8 +728,7 @@ type Retrieval struct { // ToolCodeExecution is a tool that executes code generated by the model, and automatically returns the result // to the model. See also [ExecutableCode]and [CodeExecutionResult] which are input // and output to this tool. -type ToolCodeExecution struct { -} +type ToolCodeExecution struct{} // Tool details of a tool that the model may use to generate a response. type Tool struct { diff --git a/core/providers/gemini/utils.go b/core/providers/gemini/utils.go index afb410d022..ae4339db0e 100644 --- a/core/providers/gemini/utils.go +++ b/core/providers/gemini/utils.go @@ -85,7 +85,7 @@ func effortToThinkingLevel(effort string, model string) string { return "high" // Pro models don't support medium, use high } return "medium" - case "high": + case "high", "xhigh", "max": return "high" default: if isPro { diff --git a/core/providers/gemini/videos.go b/core/providers/gemini/videos.go index 62ce110c26..43ece90be4 100644 --- a/core/providers/gemini/videos.go +++ b/core/providers/gemini/videos.go @@ -217,7 +217,7 @@ func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if operation == nil { - return nil, 
providerUtils.NewBifrostOperationError("operation is nil", nil, schemas.Gemini) + return nil, providerUtils.NewBifrostOperationError("operation is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/groq/groq.go b/core/providers/groq/groq.go index c1362fbece..f152a10e94 100644 --- a/core/providers/groq/groq.go +++ b/core/providers/groq/groq.go @@ -149,9 +149,6 @@ func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/groq/groq_test.go b/core/providers/groq/groq_test.go index 54b7945d76..3bf59588db 100644 --- a/core/providers/groq/groq_test.go +++ b/core/providers/groq/groq_test.go @@ -38,28 +38,28 @@ func TestGroq(t *testing.T) { TranscriptionModel: "whisper-large-v3", SpeechSynthesisModel: "canopylabs/orpheus-v1-english", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, - TextCompletionStream: false, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, + TextCompletionStream: false, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - FileBase64: false, // Not supported - FileURL: false, // Not supported - CompleteEnd2End: true, - Embedding: false, - ListModels: true, - Reasoning: true, - Transcription: true, - SpeechSynthesis: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + 
FileBase64: false, // Not supported + FileURL: false, // Not supported + CompleteEnd2End: true, + Embedding: false, + ListModels: true, + Reasoning: true, + Transcription: true, + SpeechSynthesis: true, }, } t.Run("GroqTests", func(t *testing.T) { diff --git a/core/providers/huggingface/errors.go b/core/providers/huggingface/errors.go index 49ce427df7..d98357e0a8 100644 --- a/core/providers/huggingface/errors.go +++ b/core/providers/huggingface/errors.go @@ -10,7 +10,7 @@ import ( ) // parseHuggingFaceImageError parses HuggingFace error responses -func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseHuggingFaceImageError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp HuggingFaceResponseError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -53,13 +53,5 @@ func parseHuggingFaceImageError(resp *fasthttp.Response, meta *providerUtils.Req bifrostErr.Error.Message = errorResp.Error } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/huggingface/huggingface.go b/core/providers/huggingface/huggingface.go index 1fed844d31..110f5d6574 100644 --- a/core/providers/huggingface/huggingface.go +++ b/core/providers/huggingface/huggingface.go @@ -254,12 +254,12 @@ func (provider *HuggingFaceProvider) completeRequest(ctx *schemas.BifrostContext // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp, nil) + return nil, latency, providerResponseHeaders, parseHuggingFaceImageError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) 
+ return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -325,7 +325,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)} + resultsChan <- providerResult{provider: inferProvider, err: providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)} return } @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext } if result.response != nil { - providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered) + providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) if providerResponse != nil { aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...) 
totalLatency += result.latency @@ -459,10 +459,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionRequest, - }, } } if inferenceProvider != "" { @@ -483,8 +479,7 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, reqBody.Stream = schemas.Ptr(false) } return reqBody, nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -518,9 +513,6 @@ func (provider *HuggingFaceProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse.Object = "chat.completion" } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -550,10 +542,6 @@ func (provider *HuggingFaceProvider) ChatCompletionStream(ctx *schemas.BifrostCo Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ChatCompletionStreamRequest, - }, } } if inferenceProvider != "" { @@ -610,9 +598,6 @@ func (provider *HuggingFaceProvider) Responses(ctx *schemas.BifrostContext, key } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -644,10 +629,6 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: 
schemas.EmbeddingRequest, - }, } } @@ -657,8 +638,7 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceEmbeddingRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -698,13 +678,10 @@ func (provider *HuggingFaceProvider) Embedding(ctx *schemas.BifrostContext, key // Unmarshal directly to BifrostEmbeddingResponse with custom logic bifrostResponse, convErr := UnmarshalHuggingFaceEmbeddingResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -735,10 +712,6 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.SpeechRequest, - }, } } @@ -747,8 +720,7 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceSpeechRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ 
-784,18 +756,15 @@ func (provider *HuggingFaceProvider) Speech(ctx *schemas.BifrostContext, key sch // Download the audio file from the URL audioData, downloadErr := provider.downloadAudioFromURL(ctx, response.Audio.URL) if downloadErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, downloadErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse, convErr := response.ToBifrostSpeechResponse(request.Model, audioData) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.SpeechRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -838,10 +807,6 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.TranscriptionRequest, - }, } } @@ -851,7 +816,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, isHFInferenceAudioRequest := inferenceProvider == hfInference if inferenceProvider == hfInference { if request.Input == nil || 
len(request.Input.File) == 0 { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests"), provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderCreateRequest, fmt.Errorf("input file data is required for hf-inference transcription requests")) } jsonData = request.Input.File } else { @@ -861,8 +826,7 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceTranscriptionRequest(request) - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -910,13 +874,10 @@ func (provider *HuggingFaceProvider) Transcription(ctx *schemas.BifrostContext, bifrostResponse, convErr := response.ToBifrostTranscriptionResponse(request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.TranscriptionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -950,10 +911,6 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationRequest, - }, } 
} @@ -963,8 +920,7 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageGenerationRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1004,15 +960,12 @@ func (provider *HuggingFaceProvider) ImageGeneration(ctx *schemas.BifrostContext // Unmarshal response using Nebius converter bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1044,10 +997,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageGenerationStreamRequest, - }, } } @@ -1055,11 +1004,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image generation streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } - providerName := 
provider.GetProviderKey() // Set headers headers := map[string]string{ @@ -1077,8 +1023,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageStreamRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1110,9 +1055,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1128,9 +1070,9 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -1139,11 +1081,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) + return nil, 
providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1160,15 +1098,14 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer providerUtils.ReleaseStreamingResponse(resp) defer close(responseChan) if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1189,6 +1126,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1207,14 +1146,7 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1237,11 +1169,6 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageGenerationStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1267,11 +1194,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageGenerationEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1311,11 +1235,8 @@ func (provider *HuggingFaceProvider) ImageGenerationStream(ctx *schemas.BifrostC Type: schemas.ImageGenerationEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageGenerationStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - 
ChunkIndex: chunkIndex, - Latency: time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ @@ -1359,10 +1280,6 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - RequestType: schemas.ImageEditRequest, - }, } } @@ -1377,8 +1294,7 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { req, err := ToHuggingFaceImageEditRequest(request) return req, err - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -1414,15 +1330,12 @@ func (provider *HuggingFaceProvider) ImageEdit(ctx *schemas.BifrostContext, key // Unmarshal response bifrostResponse, convErr := UnmarshalHuggingFaceImageGenerationResponse(responseBody, request.Model) if convErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey()) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse.Created = time.Now().Unix() // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1454,10 +1367,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Message: nameErr.Error(), Error: nameErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: 
provider.GetProviderKey(), - RequestType: schemas.ImageEditStreamRequest, - }, } } @@ -1465,9 +1374,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext if inferenceProvider != falAI { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("image edit streaming is only supported for fal-ai inference provider, got: %s", inferenceProvider), - nil, - provider.GetProviderKey(), - ) + nil) } var authHeader map[string]string @@ -1493,15 +1400,13 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) - providerName := provider.GetProviderKey() jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToHuggingFaceImageEditRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1529,9 +1434,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext req.SetBody(jsonBody) } - // Capture start time before making the HTTP request for latency calculation - startTime := time.Now() - // Make the request err := provider.client.Do(req, resp) if err != nil { @@ -1547,9 +1449,9 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers 
before status check so error responses also forward them @@ -1558,11 +1460,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseHuggingFaceImageError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1579,15 +1477,14 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer providerUtils.ReleaseStreamingResponse(resp) defer close(responseChan) if resp.BodyStream() == nil { bifrostErr := providerUtils.NewBifrostOperationError( "Provider returned an empty response", - fmt.Errorf("provider returned an empty response"), - providerName, - ) + fmt.Errorf("provider returned an empty response")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1608,6 +1505,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext sseReader := providerUtils.GetSSEDataReader(ctx, reader) + // Initialize latency timers post-handshake so chunk latency reflects pure streaming time. 
+ startTime := time.Now() lastChunkTime := startTime chunkIndex := 0 var lastB64Data, lastURLData, lastJsonData string @@ -1626,14 +1525,7 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext } bifrostErr := providerUtils.NewBifrostOperationError( fmt.Sprintf("Error reading fal-ai stream: %v", readErr), - readErr, - providerName, - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } + readErr) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger) return @@ -1656,11 +1548,6 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Error: &schemas.ErrorField{ Message: errorResp.Message, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } if errorResp.Error != "" { bifrostErr.Error.Message = errorResp.Error @@ -1686,11 +1573,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext chunk := &schemas.BifrostImageGenerationStreamResponse{ Type: schemas.ImageEditEventTypePartial, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -1730,11 +1614,8 @@ func (provider *HuggingFaceProvider) ImageEditStream(ctx *schemas.BifrostContext Type: schemas.ImageEditEventTypeCompleted, Index: lastIndex, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: 
time.Since(startTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(startTime).Milliseconds(), }, } finalChunk.BackfillParams(&schemas.BifrostRequest{ diff --git a/core/providers/huggingface/huggingface_test.go b/core/providers/huggingface/huggingface_test.go index 4a1a91802d..4c98cec1f0 100644 --- a/core/providers/huggingface/huggingface_test.go +++ b/core/providers/huggingface/huggingface_test.go @@ -51,7 +51,7 @@ func TestHuggingface(t *testing.T) { ImageBase64: true, MultipleImages: true, CompleteEnd2End: true, - Embedding: true, + Embedding: false, Transcription: true, TranscriptionStream: false, SpeechSynthesis: true, diff --git a/core/providers/huggingface/models.go b/core/providers/huggingface/models.go index c637c3b6f6..de615ccec2 100644 --- a/core/providers/huggingface/models.go +++ b/core/providers/huggingface/models.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -13,7 +14,7 @@ const ( maxModelFetchLimit = 1000 ) -func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, inferenceProvider inferenceProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -22,15 +23,20 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi Data: make([]schemas.Model, 0, len(response.Models)), } - var blacklisted map[string]struct{} - if !unfiltered && len(blacklistedModels) > 0 { - blacklisted = make(map[string]struct{}, len(blacklistedModels)) - for _, m := range 
blacklistedModels { - blacklisted[m] = struct{}{} - } + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - includedModels := make(map[string]bool) + included := make(map[string]bool) + for _, model := range response.Models { if model.ModelID == "" { continue @@ -41,39 +47,33 @@ func (response *HuggingFaceListModelsResponse) ToBifrostListModelsResponse(provi continue } - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ModelID) { - continue - } - if _, ok := blacklisted[model.ModelID]; ok { - continue - } - - newModel := schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, model.ModelID), - Name: schemas.Ptr(model.ModelID), - SupportedMethods: supported, - HuggingFaceID: schemas.Ptr(model.ID), - } - - bifrostResponse.Data = append(bifrostResponse.Data, newModel) - includedModels[model.ModelID] = true - } - - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if _, ok := blacklisted[allowedModel]; ok { - continue + // Aliases apply at the model level (model.ModelID), not at the compound + // "{providerKey}/{inferenceProvider}/{modelID}" level. 
+ for _, result := range pipeline.FilterModel(model.ModelID) { + newModel := schemas.Model{ + // inferenceProvider stays in the compound ID; aliases rename only the model segment + ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, result.ResolvedID), + Name: schemas.Ptr(model.ModelID), + SupportedMethods: supported, + HuggingFaceID: schemas.Ptr(model.ID), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, allowedModel), - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + newModel.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, newModel) + included[strings.ToLower(result.ResolvedID)] = true } } + // Backfill: run the standard pipeline, which emits plain "{providerKey}/{modelID}" IDs, + // then re-wrap each ID to restore the "{providerKey}/{inferenceProvider}/{modelID}" form. + for _, m := range pipeline.BackfillModels(included) { + // Re-wrap the backfill ID to include the inferenceProvider segment + rawID := strings.TrimPrefix(m.ID, string(providerKey)+"/") + m.ID = fmt.Sprintf("%s/%s/%s", providerKey, inferenceProvider, rawID) + bifrostResponse.Data = append(bifrostResponse.Data, m) + } + return bifrostResponse } diff --git a/core/providers/huggingface/responses.go b/core/providers/huggingface/responses.go index fd68aa76a8..35ad2c336d 100644 --- a/core/providers/huggingface/responses.go +++ b/core/providers/huggingface/responses.go @@ -43,9 +43,6 @@ func ToBifrostResponsesResponseFromHuggingFace(resp *schemas.BifrostChatResponse responsesResp := resp.ToBifrostResponsesResponse() if responsesResp != nil { - responsesResp.ExtraFields.Provider = schemas.HuggingFace - responsesResp.ExtraFields.ModelRequested = requestedModel - responsesResp.ExtraFields.RequestType = schemas.ResponsesRequest } return responsesResp, nil diff --git a/core/providers/huggingface/speech.go
b/core/providers/huggingface/speech.go index 65c0ba6e12..f702d1f39f 100644 --- a/core/providers/huggingface/speech.go +++ b/core/providers/huggingface/speech.go @@ -125,10 +125,6 @@ func (response *HuggingFaceSpeechResponse) ToBifrostSpeechResponse(requestedMode // Create the base Bifrost response with the downloaded audio data bifrostResponse := &schemas.BifrostSpeechResponse{ Audio: audioData, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Note: HuggingFace TTS API typically doesn't return usage information diff --git a/core/providers/huggingface/transcription.go b/core/providers/huggingface/transcription.go index 0d892cb07c..f3ff5c293a 100644 --- a/core/providers/huggingface/transcription.go +++ b/core/providers/huggingface/transcription.go @@ -144,10 +144,6 @@ func (response *HuggingFaceTranscriptionResponse) ToBifrostTranscriptionResponse // Create the base Bifrost response bifrostResponse := &schemas.BifrostTranscriptionResponse{ Text: response.Text, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.HuggingFace, - ModelRequested: requestedModel, - }, } // Map chunks to segments if available diff --git a/core/providers/huggingface/utils.go b/core/providers/huggingface/utils.go index ad68e4c6ac..b96210c832 100644 --- a/core/providers/huggingface/utils.go +++ b/core/providers/huggingface/utils.go @@ -221,8 +221,6 @@ func convertToInferenceProviderMappings(resp *HuggingFaceInferenceProviderMappin } func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx context.Context, huggingfaceModelName string) (map[inferenceProvider]HuggingFaceInferenceProviderMapping, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Check cache first if cached, ok := provider.modelProviderMappingCache.Load(huggingfaceModelName); ok { if mappings, ok := cached.(map[inferenceProvider]HuggingFaceInferenceProviderMapping); ok { @@ -259,12 +257,12 @@ 
func (provider *HuggingFaceProvider) getModelInferenceProviderMapping(ctx contex body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var mappingResp HuggingFaceInferenceProviderMappingResponse if err := sonic.Unmarshal(body, &mappingResp); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } mappings := convertToInferenceProviderMappings(&mappingResp) diff --git a/core/providers/mistral/custom_provider_test.go b/core/providers/mistral/custom_provider_test.go index 0d8e283f51..9015230544 100644 --- a/core/providers/mistral/custom_provider_test.go +++ b/core/providers/mistral/custom_provider_test.go @@ -25,7 +25,7 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { resp.SetStatusCode(http.StatusBadRequest) resp.SetBodyString(`{"message":"invalid request","type":"invalid_request_error","code":"bad_request"}`) - bifrostErr := ParseMistralError(resp, schemas.OCRRequest, customMistralProviderName, "mistral-ocr-latest") + bifrostErr := ParseMistralError(resp) require.NotNil(t, bifrostErr) require.NotNil(t, bifrostErr.Error) @@ -33,8 +33,6 @@ func TestParseMistralError_UsesExportedConverterMetadata(t *testing.T) { assert.Equal(t, schemas.Ptr("invalid_request_error"), bifrostErr.Error.Type) assert.Equal(t, schemas.Ptr("bad_request"), bifrostErr.Error.Code) assert.Equal(t, customMistralProviderName, bifrostErr.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, bifrostErr.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", bifrostErr.ExtraFields.ModelRequested) } func TestMistralProvider_CustomAliasChatStreamUsesBaseCompatibilityAndAliasMetadata(t 
*testing.T) { @@ -156,5 +154,5 @@ func TestMistralProvider_CustomAliasEmbeddingReportsAliasMetadata(t *testing.T) require.NotNil(t, response) assert.Equal(t, customMistralProviderName, response.ExtraFields.Provider) - assert.Equal(t, "codestral-embed", response.ExtraFields.ModelRequested) + assert.Equal(t, "codestral-embed", response.ExtraFields.ResolvedModelUsed) } diff --git a/core/providers/mistral/errors.go b/core/providers/mistral/errors.go index 2ae260eb9a..cbbd40a560 100644 --- a/core/providers/mistral/errors.go +++ b/core/providers/mistral/errors.go @@ -19,7 +19,7 @@ type MistralErrorResponse struct { } // ParseMistralError parses Mistral-specific error responses. -func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseMistralError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp MistralErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) if bifrostErr == nil { @@ -67,9 +67,5 @@ func ParseMistralError(resp *fasthttp.Response, requestType schemas.RequestType, } } - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/mistral/mistral.go b/core/providers/mistral/mistral.go index 597bf6c239..1999cbb5fb 100644 --- a/core/providers/mistral/mistral.go +++ b/core/providers/mistral/mistral.go @@ -70,14 +70,12 @@ func NewMistralProvider(config *schemas.ProviderConfig, logger schemas.Logger) * // GetProviderKey returns the provider identifier for Mistral. func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { - return schemas.Mistral + return providerUtils.GetProviderName(schemas.Mistral, provider.customProviderConfig) } // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. 
func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -103,7 +101,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseMistralError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseMistralError(resp) return nil, bifrostErr } @@ -118,7 +116,7 @@ func (provider *MistralProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } // Create final response - response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels) + response := mistralResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered) response.ExtraFields.Latency = latency.Milliseconds() @@ -158,13 +156,27 @@ func (provider *MistralProvider) TextCompletionStream(ctx *schemas.BifrostContex return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey()) } +// normalizeChatRequestForConversion returns the request unchanged for the stock Mistral +// provider. For custom aliases (e.g. a provider registered as "custom-mistral" with +// BaseProviderType=Mistral), it returns a shallow copy with Provider set to schemas.Mistral +// so the shared OpenAI converter applies Mistral-specific compatibility (max_completion_tokens +// → max_tokens, tool_choice struct → "any"). The caller's request is never mutated. 
+func (provider *MistralProvider) normalizeChatRequestForConversion(request *schemas.BifrostChatRequest) *schemas.BifrostChatRequest { + if request == nil || provider.customProviderConfig == nil || request.Provider == schemas.Mistral { + return request + } + normalized := *request + normalized.Provider = schemas.Mistral + return &normalized +} + // ChatCompletion performs a chat completion request to the Mistral API. func (provider *MistralProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), - request, + provider.normalizeChatRequestForConversion(request), key, provider.networkConfig.ExtraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -190,7 +202,7 @@ func (provider *MistralProvider) ChatCompletionStream(ctx *schemas.BifrostContex ctx, provider.client, provider.networkConfig.BaseURL+"/v1/chat/completions", - request, + provider.normalizeChatRequestForConversion(request), authHeader, provider.networkConfig.ExtraHeaders, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -214,9 +226,6 @@ func (provider *MistralProvider) Responses(ctx *schemas.BifrostContext, key sche } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -264,25 +273,23 @@ func (provider *MistralProvider) Rerank(ctx *schemas.BifrostContext, key schemas // OCR performs an OCR request to the Mistral API. // It sends a JSON request to Mistral's OCR endpoint and returns the extracted content. 
func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralOCRRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("ocr request input is not provided", nil) } // Marshal request body requestBody, err := sonic.Marshal(mistralReq) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge extra params into JSON payload if len(mistralReq.ExtraParams) > 0 { requestBody, err = providerUtils.MergeExtraParamsIntoJSON(requestBody, mistralReq.ExtraParams) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } @@ -318,13 +325,12 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, key schemas.Ke // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.OCRRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -352,20 +358,17 @@ func (provider *MistralProvider) OCR(ctx *schemas.BifrostContext, 
key schemas.Ke }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostOCRResponse() if response == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert ocr response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.OCRRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -387,16 +390,14 @@ func (provider *MistralProvider) SpeechStream(ctx *schemas.BifrostContext, postH // It creates a multipart form with the audio file and sends it to Mistral's transcription endpoint. // Returns the transcribed text and metadata, or an error if the request fails. 
func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } // Create multipart form body - body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, providerName) + body, contentType, bifrostErr := createMistralTranscriptionMultipartBody(mistralReq, provider.GetProviderKey()) if bifrostErr != nil { return nil, bifrostErr } @@ -428,13 +429,12 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } responseBody, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Check for empty response @@ -462,20 +462,17 @@ func (provider *MistralProvider) Transcription(ctx *schemas.BifrostContext, key }, } } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } // Convert to Bifrost format response := mistralResponse.ToBifrostTranscriptionResponse() if response == 
nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert transcription response", nil) } // Set extra fields response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.RequestType = schemas.TranscriptionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model // Set raw response if enabled if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { @@ -497,7 +494,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Convert Bifrost request to Mistral format mistralReq := ToMistralTranscriptionRequest(request) if mistralReq == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } mistralReq.Stream = schemas.Ptr(true) @@ -552,9 +549,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -563,8 +560,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - provider.logger.Debug("error from %s 
provider: %s", providerName, string(resp.Body())) - return nil, ParseMistralError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, ParseMistralError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -583,9 +579,9 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext go func() { defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -602,6 +598,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext // which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer). 
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger) defer stopCancellation() + defer providerUtils.EnsureStreamFinalizerCalled(ctx) sseReader := providerUtils.GetSSEEventReader(ctx, reader) chunkIndex := -1 @@ -624,7 +621,7 @@ func (provider *MistralProvider) TranscriptionStream(ctx *schemas.BifrostContext } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, provider.logger) } break } @@ -672,11 +669,6 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, &bifrostErr, responseChan, provider.logger) return @@ -705,11 +697,8 @@ func (provider *MistralProvider) processTranscriptionStreamEvent( // Set extra fields response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: model, - ChunkIndex: chunkIndex, - Latency: time.Since(*lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(*lastChunkTime).Milliseconds(), } *lastChunkTime = time.Now() diff --git a/core/providers/mistral/models.go b/core/providers/mistral/models.go index ef3e5934c1..8d5fd7f3d6 100644 --- a/core/providers/mistral/models.go +++ b/core/providers/mistral/models.go 
@@ -1,12 +1,13 @@ package mistral import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse { +func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -15,40 +16,40 @@ func (response *MistralListModelsResponse) ToBifrostListModelsResponse(allowedMo Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { - continue - } - if slices.Contains(blacklistedModels, model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + model.ID, - Name: schemas.Ptr(model.Name), - Description: schemas.Ptr(model.Description), - Created: schemas.Ptr(model.Created), - ContextLength: schemas.Ptr(int(model.MaxContextLength)), - OwnedBy: schemas.Ptr(model.OwnedBy), - }) - includedModels[model.ID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: schemas.Mistral, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range 
pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(schemas.Mistral) + "/" + result.ResolvedID, + Name: schemas.Ptr(model.Name), + Description: schemas.Ptr(model.Description), + Created: schemas.Ptr(model.Created), + ContextLength: schemas.Ptr(int(model.MaxContextLength)), + OwnedBy: schemas.Ptr(model.OwnedBy), } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(schemas.Mistral) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - includedModels[allowedModel] = true + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) + return bifrostResponse } diff --git a/core/providers/mistral/ocr_test.go b/core/providers/mistral/ocr_test.go index c88bca5f48..ccf7223e4c 100644 --- a/core/providers/mistral/ocr_test.go +++ b/core/providers/mistral/ocr_test.go @@ -436,9 +436,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, 1, resp.Pages[1].Index) require.NotNil(t, resp.UsageInfo) assert.Equal(t, 2, resp.UsageInfo.PagesProcessed) - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) - assert.Equal(t, "mistral-ocr-latest", resp.ExtraFields.ModelRequested) }, }, { @@ -503,9 +500,6 @@ func TestOCRWithMockServer(t *testing.T) { assert.Equal(t, "server_error", *err.Error.Type) require.NotNil(t, err.Error.Code) assert.Equal(t, "internal_error", *err.Error.Code) - assert.Equal(t, schemas.Mistral, err.ExtraFields.Provider) - assert.Equal(t, schemas.OCRRequest, err.ExtraFields.RequestType) - assert.Equal(t, "mistral-ocr-latest", err.ExtraFields.ModelRequested) }, }, { @@ -757,7 +751,5 @@ func TestMistralOCRIntegration(t *testing.T) { require.NotEmpty(t, resp.Pages, 
"Expected at least one page") assert.Equal(t, 0, resp.Pages[0].Index) assert.NotEmpty(t, resp.Pages[0].Markdown, "Expected non-empty markdown for page 0") - assert.Equal(t, schemas.OCRRequest, resp.ExtraFields.RequestType) - assert.Equal(t, schemas.Mistral, resp.ExtraFields.Provider) assert.Greater(t, resp.ExtraFields.Latency, int64(0)) } diff --git a/core/providers/mistral/transcription.go b/core/providers/mistral/transcription.go index a4a018e5c6..fe9b262126 100644 --- a/core/providers/mistral/transcription.go +++ b/core/providers/mistral/transcription.go @@ -109,58 +109,58 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, req *Mi } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(req.File); err != nil { - return providerUtils.NewBifrostOperationError("failed to write file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write file data", err) } // Add model field (required) if err := writer.WriteField("model", req.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add stream field if streaming if req.Stream != nil && *req.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } // Add optional fields if req.Language != nil { if err := writer.WriteField("language", *req.Language); err != nil { - return providerUtils.NewBifrostOperationError("failed to write language field", err, providerName) + return 
providerUtils.NewBifrostOperationError("failed to write language field", err) } } if req.Prompt != nil { if err := writer.WriteField("prompt", *req.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } } if req.ResponseFormat != nil { if err := writer.WriteField("response_format", *req.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if req.Temperature != nil { if err := writer.WriteField("temperature", formatFloat64(*req.Temperature)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write temperature field", err) } } for _, granularity := range req.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } // Close the multipart writer to finalize the form if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/nebius/errors.go b/core/providers/nebius/errors.go index 98d0fb78d8..de8bcf0d84 100644 --- a/core/providers/nebius/errors.go +++ b/core/providers/nebius/errors.go @@ -9,7 +9,7 @@ import ( ) // parseNebiusImageError parses Nebius error responses -func parseNebiusImageError(resp *fasthttp.Response, meta 
*providerUtils.RequestMetadata) *schemas.BifrostError { +func parseNebiusImageError(resp *fasthttp.Response) *schemas.BifrostError { var nebiusErr NebiusError bifrostErr := providerUtils.HandleProviderAPIError(resp, &nebiusErr) @@ -60,13 +60,5 @@ func parseNebiusImageError(resp *fasthttp.Response, meta *providerUtils.RequestM bifrostErr.Error.Message = message } - if meta != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: meta.Provider, - ModelRequested: meta.Model, - RequestType: meta.RequestType, - } - } - return bifrostErr } diff --git a/core/providers/nebius/nebius.go b/core/providers/nebius/nebius.go index 998588ce71..eac617df6e 100644 --- a/core/providers/nebius/nebius.go +++ b/core/providers/nebius/nebius.go @@ -193,9 +193,6 @@ func (provider *NebiusProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -265,16 +262,15 @@ func (provider *NebiusProvider) TranscriptionStream(ctx *schemas.BifrostContext, func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Validate request is not nil if request == nil { - return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("image generation request is nil", nil) } // Validate input and prompt are not nil/empty if request.Input == nil || strings.TrimSpace(request.Input.Prompt) == "" { - return nil, providerUtils.NewBifrostOperationError("prompt cannot be empty", nil, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError("prompt cannot be 
empty", nil) } path := providerUtils.GetPathFromContext(ctx, "/v1/images/generations") - providerName := schemas.Nebius // Create request req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -309,8 +305,7 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key request, func() (providerUtils.RequestBodyWithExtraParams, error) { return provider.ToNebiusImageGenerationRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -328,16 +323,12 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseNebiusImageError(resp), jsonData, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } response := &schemas.BifrostImageGenerationResponse{} @@ -357,9 +348,6 @@ func (provider *NebiusProvider) ImageGeneration(ctx *schemas.BifrostContext, key return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageGenerationRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request 
if enabled diff --git a/core/providers/nebius/nebius_test.go b/core/providers/nebius/nebius_test.go index 898617cb1e..da6c4065e8 100644 --- a/core/providers/nebius/nebius_test.go +++ b/core/providers/nebius/nebius_test.go @@ -32,25 +32,25 @@ func TestNebius(t *testing.T) { EmbeddingModel: "BAAI/bge-en-icl", ImageGenerationModel: "black-forest-labs/flux-schnell", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - ImageGeneration: true, - CompleteEnd2End: true, - ImageGenerationStream: false, - Embedding: true, // Nebius supports embeddings - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + ImageGeneration: true, + CompleteEnd2End: true, + ImageGenerationStream: false, + Embedding: true, // Nebius supports embeddings + ListModels: true, }, } diff --git a/core/providers/ollama/ollama.go b/core/providers/ollama/ollama.go index 6710169160..b84d3f7c9c 100644 --- a/core/providers/ollama/ollama.go +++ b/core/providers/ollama/ollama.go @@ -3,7 +3,6 @@ package ollama import ( - "fmt" "strings" "time" @@ -50,11 +49,7 @@ func NewOllamaProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for Ollama - if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is 
required for ollama provider") - } - + // BaseURL is optional when keys have ollama_key_config with per-key URLs return &OllamaProvider{ logger: logger, client: client, @@ -69,17 +64,14 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } -// ListModels performs a list models request to Ollama's API. -func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) - } - return openai.HandleOpenAIListModelsRequest( +// listModelsByKey performs a list models request for a single Ollama key. +func (provider *OllamaProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return openai.ListModelsByKey( ctx, provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), + key, + request.Unfiltered, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), @@ -87,12 +79,24 @@ func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []s ) } +// ListModels performs a list models request to Ollama's API. +// Requests are made concurrently per key so that each backend is queried +// with its own URL (from ollama_key_config). 
+func (provider *OllamaProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return providerUtils.HandleMultipleListModelsRequests( + ctx, + keys, + request, + provider.listModelsByKey, + ) +} + // TextCompletion performs a text completion request to the Ollama API. func (provider *OllamaProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -112,7 +116,7 @@ func (provider *OllamaProvider) TextCompletionStream(ctx *schemas.BifrostContext return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/completions", + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -132,7 +136,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx *schemas.BifrostContext, key return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -154,7 +158,7 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx *schemas.BifrostContext return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, 
"/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -179,9 +183,6 @@ func (provider *OllamaProvider) Responses(ctx *schemas.BifrostContext, key schem } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -202,7 +203,7 @@ func (provider *OllamaProvider) Embedding(ctx *schemas.BifrostContext, key schem return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.OllamaKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/ollama/ollama_test.go b/core/providers/ollama/ollama_test.go index ad31297046..a9005c1220 100644 --- a/core/providers/ollama/ollama_test.go +++ b/core/providers/ollama/ollama_test.go @@ -29,24 +29,24 @@ func TestOllama(t *testing.T) { TextModel: "", // Ollama doesn't support text completion in newer models EmbeddingModel: "", // Ollama doesn't support embedding Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - FileBase64: false, - FileURL: false, - CompleteEnd2End: true, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: 
false, + FileBase64: false, + FileURL: false, + CompleteEnd2End: true, + Embedding: false, + ListModels: true, }, } diff --git a/core/providers/openai/batch.go b/core/providers/openai/batch.go index ae095e5c77..ec8ce468bb 100644 --- a/core/providers/openai/batch.go +++ b/core/providers/openai/batch.go @@ -10,10 +10,10 @@ import ( // OpenAIBatchRequest represents the request body for creating a batch. type OpenAIBatchRequest struct { - InputFileID string `json:"input_file_id"` - Endpoint string `json:"endpoint"` - CompletionWindow string `json:"completion_window"` - Metadata map[string]string `json:"metadata,omitempty"` + InputFileID string `json:"input_file_id"` + Endpoint string `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]string `json:"metadata,omitempty"` OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"` } @@ -82,7 +82,7 @@ func ToBifrostBatchStatus(status string) schemas.BatchStatus { } // ToBifrostBatchCreateResponse converts OpenAI batch response to Bifrost batch response. -func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse { resp := &schemas.BifrostBatchCreateResponse{ ID: r.ID, Object: r.Object, @@ -95,9 +95,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. 
OutputFileID: r.OutputFileID, ErrorFileID: r.ErrorFileID, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchCreateRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -125,7 +123,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(providerName schemas. } // ToBifrostBatchRetrieveResponse converts OpenAI batch response to Bifrost batch retrieve response. -func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { +func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse { resp := &schemas.BifrostBatchRetrieveResponse{ ID: r.ID, Object: r.Object, @@ -146,9 +144,7 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema ErrorFileID: r.ErrorFileID, Errors: r.Errors, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.BatchRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -174,35 +170,3 @@ func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(providerName schema return resp } - -// splitJSONL splits JSONL content into individual lines. 
-func splitJSONL(data []byte) [][]byte { - var lines [][]byte - start := 0 - for i, b := range data { - if b == '\n' { - if i > start { - end := i - // Strip trailing \r if present (handle CRLF) - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - start = i + 1 - } - } - if start < len(data) { - end := len(data) - // Strip trailing \r if present - if end > start && data[end-1] == '\r' { - end-- - } - if end > start { - lines = append(lines, data[start:end]) - } - } - return lines -} diff --git a/core/providers/openai/chat.go b/core/providers/openai/chat.go index 2e2bfd43b8..48141f46d1 100644 --- a/core/providers/openai/chat.go +++ b/core/providers/openai/chat.go @@ -104,12 +104,17 @@ func (req *OpenAIChatRequest) filterOpenAISpecificParameters() { // Handle reasoning parameter: OpenAI uses effort-based reasoning // Priority: effort (native) > max_tokens (estimated) if req.ChatParameters.Reasoning != nil { + reasoningCopy := *req.ChatParameters.Reasoning + req.ChatParameters.Reasoning = &reasoningCopy if req.ChatParameters.Reasoning.Effort != nil { // Native field is provided, use it (and clear max_tokens) effort := *req.ChatParameters.Reasoning.Effort - // Convert "minimal" to "low" for non-OpenAI providers - if effort == "minimal" { + // Convert "minimal" to "low"; cap "xhigh"/"max" to "high" — OpenAI tops out at high. 
+ switch effort { + case "minimal": req.ChatParameters.Reasoning.Effort = schemas.Ptr("low") + case "xhigh", "max": + req.ChatParameters.Reasoning.Effort = schemas.Ptr("high") } // Clear max_tokens since OpenAI doesn't use it req.ChatParameters.Reasoning.MaxTokens = nil diff --git a/core/providers/openai/chat_test.go b/core/providers/openai/chat_test.go index f391f821cb..724c438d91 100644 --- a/core/providers/openai/chat_test.go +++ b/core/providers/openai/chat_test.go @@ -2,11 +2,13 @@ package openai import ( "encoding/json" + "strings" "testing" "github.com/bytedance/sonic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/stretchr/testify/require" ) func TestToOpenAIChatRequest_ToolNormalization(t *testing.T) { @@ -78,6 +80,32 @@ func TestToOpenAIChatRequest_ToolNormalization(t *testing.T) { } } +func TestToOpenAIChatRequest_PreservesN(t *testing.T) { + req := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4.1", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("hello"), + }, + }, + }, + Params: &schemas.ChatParameters{ + N: schemas.Ptr(2), + }, + } + + out := ToOpenAIChatRequest(schemas.NewBifrostContext(nil, schemas.NoDeadline), req) + if out == nil { + t.Fatal("expected request") + } + if out.N == nil || *out.N != 2 { + t.Fatalf("expected n=2, got %#v", out.N) + } +} + func TestToOpenAIChatRequest_PreservesPropertyOrder(t *testing.T) { params := &schemas.ToolFunctionParameters{ Type: "object", @@ -277,7 +305,6 @@ func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) @@ -307,6 +334,68 @@ func 
TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes } } +// TestToOpenAIChatRequest_AnnotationsNotInWirePayload verifies that MCPToolAnnotations +// (stored on ChatTool with json:"-") are never included in the JSON body sent to OpenAI. +func TestToOpenAIChatRequest_AnnotationsNotInWirePayload(t *testing.T) { + readOnly := true + + bifrostReq := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o", + Input: []schemas.ChatMessage{ + {Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}}, + }, + Params: &schemas.ChatParameters{ + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "read_file", + Description: schemas.Ptr("Read a file"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: schemas.NewOrderedMapFromPairs( + schemas.KV("path", map[string]interface{}{"type": "string"}), + ), + Required: []string{"path"}, + }, + }, + Annotations: &schemas.MCPToolAnnotations{ + Title: "File Reader", + ReadOnlyHint: &readOnly, + }, + }, + }, + }, + } + + ctx, cancel := schemas.NewBifrostContextWithCancel(nil) + defer cancel() + + result := ToOpenAIChatRequest(ctx, bifrostReq) + require.NotNil(t, result) + + wireBody, err := json.Marshal(result) + require.NoError(t, err) + s := string(wireBody) + + // Annotations must be absent from the wire payload + if strings.Contains(s, "annotations") { + t.Errorf("annotations field leaked into OpenAI wire payload: %s", s) + } + if strings.Contains(s, "readOnlyHint") { + t.Errorf("readOnlyHint leaked into OpenAI wire payload: %s", s) + } + if strings.Contains(s, "File Reader") { + t.Errorf("annotation title leaked into OpenAI wire payload: %s", s) + } + + // The function definition must still be intact + if !strings.Contains(s, "read_file") { + t.Errorf("function name missing from OpenAI wire payload: %s", s) + } +} + func TestApplyXAICompatibility(t 
*testing.T) { tests := []struct { name string diff --git a/core/providers/openai/errors.go b/core/providers/openai/errors.go index 6a5bc1ce08..69d0aff407 100644 --- a/core/providers/openai/errors.go +++ b/core/providers/openai/errors.go @@ -10,10 +10,10 @@ import ( ) // ErrorConverter is a function that converts provider-specific error responses to BifrostError. -type ErrorConverter func(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError +type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError // ParseOpenAIError parses OpenAI error responses. -func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp schemas.BifrostError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -49,11 +49,6 @@ func ParseOpenAIError(resp *fasthttp.Response, requestType schemas.RequestType, } // Set ExtraFields unconditionally so provider/model/request metadata is always attached - if bifrostErr != nil { - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - } return bifrostErr } diff --git a/core/providers/openai/errors_test.go b/core/providers/openai/errors_test.go index f33008600b..1132a92723 100644 --- a/core/providers/openai/errors_test.go +++ b/core/providers/openai/errors_test.go @@ -12,7 +12,7 @@ func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *tes resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`) - errResp := ParseOpenAIError(&resp, schemas.ResponsesStreamRequest, schemas.Cerebras, "llama3.1-8b") + errResp := ParseOpenAIError(&resp) if errResp 
== nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -29,7 +29,7 @@ func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) { resp.SetStatusCode(fasthttp.StatusUnprocessableEntity) resp.SetBodyString(`{"error":{"message":"unsupported role: developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -43,7 +43,7 @@ func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBody(nil) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -59,7 +59,7 @@ func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) { resp.SetStatusCode(fasthttp.StatusBadRequest) resp.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } @@ -73,7 +73,7 @@ func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing. // fasthttp defaults zero-value response status code to 200. 
resp.SetBodyString(`{"error":{"message":""}}`) - errResp := ParseOpenAIError(&resp, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4o") + errResp := ParseOpenAIError(&resp) if errResp == nil || errResp.Error == nil { t.Fatal("expected non-nil error response") } diff --git a/core/providers/openai/files.go b/core/providers/openai/files.go index bbaf2b2f70..133250cac7 100644 --- a/core/providers/openai/files.go +++ b/core/providers/openai/files.go @@ -55,7 +55,7 @@ func ToBifrostFileStatus(status string) schemas.FileStatus { } // ToBifrostFileUploadResponse converts OpenAI file response to Bifrost file upload response. -func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { +func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse { resp := &schemas.BifrostFileUploadResponse{ ID: r.ID, Object: r.Object, @@ -67,9 +67,7 @@ func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(providerName schemas.Mo StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -97,9 +95,7 @@ func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas. 
StatusDetails: r.StatusDetails, StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } diff --git a/core/providers/openai/images.go b/core/providers/openai/images.go index f183e17ec5..9176f1e1e7 100644 --- a/core/providers/openai/images.go +++ b/core/providers/openai/images.go @@ -125,18 +125,18 @@ func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *Open func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add prompt field (required) if err := writer.WriteField("prompt", openaiReq.Input.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add stream field when requesting streaming if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return providerUtils.NewBifrostOperationError("failed to write stream field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write stream field", err) } } @@ -168,71 +168,71 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create 
form part for image %d", i), err) } if _, err := part.Write(imageInput.Image); err != nil { - return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err, providerName) + return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err) } } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Quality != nil { if err := writer.WriteField("quality", *openaiReq.Quality); err != nil { - return providerUtils.NewBifrostOperationError("failed to write quality field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write quality field", err) } } if openaiReq.Background != nil { if err := writer.WriteField("background", *openaiReq.Background); err != nil { - return providerUtils.NewBifrostOperationError("failed to write background field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write background field", err) } } if openaiReq.InputFidelity != nil { if err := writer.WriteField("input_fidelity", *openaiReq.InputFidelity); err != nil { - return 
providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err) } } if openaiReq.PartialImages != nil { if err := writer.WriteField("partial_images", strconv.Itoa(*openaiReq.PartialImages)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write partial_images field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write partial_images field", err) } } if openaiReq.OutputFormat != nil { if err := writer.WriteField("output_format", *openaiReq.OutputFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_format field", err) } } if openaiReq.OutputCompression != nil { if err := writer.WriteField("output_compression", strconv.Itoa(*openaiReq.OutputCompression)); err != nil { - return providerUtils.NewBifrostOperationError("failed to write output_compression field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write output_compression field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } @@ -260,16 +260,16 @@ func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq * "Content-Type": {maskMimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create mask form part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create mask form part", err) } if _, err := maskPart.Write(openaiReq.Mask); err != nil { - return providerUtils.NewBifrostOperationError("failed to write mask data", err, providerName) + 
return providerUtils.NewBifrostOperationError("failed to write mask data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil @@ -299,12 +299,12 @@ func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add model field (required) if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } // Add image file (required) if openaiReq.Input == nil || openaiReq.Input.Image.Image == nil || len(openaiReq.Input.Image.Image) == 0 { - return providerUtils.NewBifrostOperationError("image is required", nil, providerName) + return providerUtils.NewBifrostOperationError("image is required", nil) } // Detect MIME type @@ -320,41 +320,41 @@ func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openai "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create image part", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create image part", err) } if _, err := part.Write(openaiReq.Input.Image.Image); err != nil { - return providerUtils.NewBifrostOperationError("failed to write image data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write image data", err) } // Add optional parameters if openaiReq.N != nil { if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil { - return 
providerUtils.NewBifrostOperationError("failed to write n field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write n field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return providerUtils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Size != nil { if err := writer.WriteField("size", *openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } if openaiReq.User != nil { if err := writer.WriteField("user", *openaiReq.User); err != nil { - return providerUtils.NewBifrostOperationError("failed to write user field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write user field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/large_payload.go b/core/providers/openai/large_payload.go index 461f3417de..fe3aaf1812 100644 --- a/core/providers/openai/large_payload.go +++ b/core/providers/openai/large_payload.go @@ -42,8 +42,6 @@ func handleOpenAILargePayloadPassthrough( key schemas.Key, extraHeaders map[string]string, providerName schemas.ModelProvider, - model string, - requestType schemas.RequestType, logger schemas.Logger, ) (*largePayloadResult, *schemas.BifrostError, bool) { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) @@ -91,7 +89,7 @@ func handleOpenAILargePayloadPassthrough( // Error responses are always small 
— materialize stream body for error parsing if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - parsedErr := ParseOpenAIError(resp, requestType, providerName, model) + parsedErr := ParseOpenAIError(resp) fasthttp.ReleaseResponse(resp) return nil, parsedErr, true } @@ -126,7 +124,7 @@ func finalizeOpenAIResponse( providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, *largePayloadResult, *schemas.BifrostError) { - body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, logger) + body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger) if bifrostErr != nil { fasthttp.ReleaseResponse(resp) return nil, nil, bifrostErr diff --git a/core/providers/openai/models.go b/core/providers/openai/models.go index d00a8af112..a76d350d28 100644 --- a/core/providers/openai/models.go +++ b/core/providers/openai/models.go @@ -1,13 +1,14 @@ package openai import ( - "slices" + "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) // ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response -func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse { +func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse { if response == nil { return nil } @@ -16,38 +17,39 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe Data: make([]schemas.Model, 0, len(response.Data)), } - includedModels := make(map[string]bool) - for _, model := range response.Data { - if !unfiltered 
&& len(allowedModels) > 0 && !slices.Contains(allowedModels, model.ID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, model.ID) { - continue - } - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + model.ID, - Created: model.Created, - OwnedBy: schemas.Ptr(model.OwnedBy), - ContextLength: model.ContextWindow, - }) - includedModels[model.ID] = true + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue + included := make(map[string]bool) + + for _, model := range response.Data { + for _, result := range pipeline.FilterModel(model.ID) { + entry := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Created: model.Created, + OwnedBy: schemas.Ptr(model.OwnedBy), + ContextLength: model.ContextWindow, } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) } + bifrostResponse.Data = append(bifrostResponse.Data, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
+ return bifrostResponse } diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 6309445c0f..a4e06dac47 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -166,7 +166,7 @@ func ListModelsByKey( // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - bifrostErr := ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := ParseOpenAIError(resp) return nil, bifrostErr } @@ -181,10 +181,8 @@ func ListModelsByKey( return nil, bifrostErr } - response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, unfiltered) + response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, unfiltered) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -289,22 +287,22 @@ func HandleOpenAITextCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TextCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostTextCompletionResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: 
request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -313,8 +311,7 @@ func HandleOpenAITextCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -335,9 +332,9 @@ func HandleOpenAITextCompletionRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -349,7 +346,7 @@ func HandleOpenAITextCompletionRequest( return &schemas.BifrostTextCompletionResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: 
schemas.TextCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -367,9 +364,6 @@ func HandleOpenAITextCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.TextCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -455,8 +449,7 @@ func HandleOpenAITextCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr @@ -501,9 +494,9 @@ func HandleOpenAITextCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -514,9 +507,9 @@ func HandleOpenAITextCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if 
customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.TextCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -531,11 +524,12 @@ func HandleOpenAITextCompletionStreaming( // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TextCompletionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -557,7 +551,7 @@ func HandleOpenAITextCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -584,7 +578,7 @@ func HandleOpenAITextCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -595,11 +589,6 @@ func HandleOpenAITextCompletionStreaming( rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { // TODO fix this - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -618,11 +607,6 @@ func HandleOpenAITextCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, 
nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -699,9 +683,6 @@ func HandleOpenAITextCompletionStreaming( if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil { chunkIndex++ - response.ExtraFields.RequestType = schemas.TextCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -719,7 +700,7 @@ func HandleOpenAITextCompletionStreaming( } } - response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest, providerName, request.Model) + response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest) if postResponseConverter != nil { response = postResponseConverter(response) if response == nil { @@ -811,22 +792,22 @@ func HandleOpenAIChatCompletionRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ChatCompletionRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostChatResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: 
providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -835,8 +816,7 @@ func HandleOpenAIChatCompletionRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -858,9 +838,9 @@ func HandleOpenAIChatCompletionRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -872,7 +852,7 @@ func HandleOpenAIChatCompletionRequest( return &schemas.BifrostChatResponse{ Model: request.Model, Usage: lpResult.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: 
request.Model, RequestType: schemas.ChatCompletionRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } response := &schemas.BifrostChatResponse{} @@ -890,9 +870,6 @@ func HandleOpenAIChatCompletionRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ChatCompletionRequest response.ExtraFields.Latency = latency.Milliseconds() // Set raw request if enabled @@ -1008,8 +985,7 @@ func HandleOpenAIChatCompletionStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1054,9 +1030,9 @@ func HandleOpenAIChatCompletionStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1067,9 +1043,9 @@ func HandleOpenAIChatCompletionStreaming( defer providerUtils.ReleaseStreamingResponse(resp) 
providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ChatCompletionStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1082,19 +1058,14 @@ func HandleOpenAIChatCompletionStreaming( // Create response channel responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) - // Determine request type for cleanup - streamRequestType := schemas.ChatCompletionStreamRequest - if isResponsesToChatCompletionsFallback { - streamRequestType = schemas.ResponsesStreamRequest - } - // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, streamRequestType, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } // Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback) schemas.ReleaseChatToResponsesStreamState(responsesStreamState) @@ -1118,7 +1089,7 @@ func 
HandleOpenAIChatCompletionStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1132,6 +1103,8 @@ func HandleOpenAIChatCompletionStreaming( var finishReason *string var messageID string + var modelName string + var created int forwardedTerminalFinishReason := false for { @@ -1147,7 +1120,7 @@ func HandleOpenAIChatCompletionStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, streamRequestType, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) return } break @@ -1160,11 +1133,6 @@ func HandleOpenAIChatCompletionStreaming( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -1178,11 +1146,6 @@ func HandleOpenAIChatCompletionStreaming( if customResponseHandler != nil { rawRequest, rawResponse, handlerErr := 
customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse) if handlerErr != nil { - handlerErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: streamRequestType, - } if sendBackRawRequest { handlerErr.ExtraFields.RawRequest = rawRequest } @@ -1213,11 +1176,6 @@ func HandleOpenAIChatCompletionStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: streamRequestType, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1235,9 +1193,6 @@ func HandleOpenAIChatCompletionStreaming( return } - response.ExtraFields.RequestType = streamRequestType - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber if sendBackRawResponse { @@ -1300,6 +1255,10 @@ func HandleOpenAIChatCompletionStreaming( response.Usage = nil } + if response.Model != "" { + modelName = response.Model + } + // Skip empty responses or responses without choices if len(response.Choices) == 0 { continue @@ -1315,11 +1274,14 @@ func HandleOpenAIChatCompletionStreaming( if response.ID != "" && messageID == "" { messageID = response.ID } + if response.Created != 0 && created == 0 { + created = response.Created + } // Handle regular content chunks, including reasoning if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil && - (choice.ChatStreamResponseChoice.Delta.Content != nil || + ((choice.ChatStreamResponseChoice.Delta.Content != nil && *choice.ChatStreamResponseChoice.Delta.Content != "") || choice.ChatStreamResponseChoice.Delta.Reasoning != nil || len(choice.ChatStreamResponseChoice.Delta.ReasoningDetails) > 0 || choice.ChatStreamResponseChoice.Delta.Audio != nil || @@ -1329,9 
+1291,6 @@ func HandleOpenAIChatCompletionStreaming( } chunkIndex++ - response.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds() lastChunkTime = time.Now() @@ -1355,7 +1314,7 @@ func HandleOpenAIChatCompletionStreaming( if forwardedTerminalFinishReason { finalFinishReason = nil } - response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, streamRequestType, providerName, request.Model) + response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, modelName, created) if postResponseConverter != nil { response = postResponseConverter(response) } @@ -1442,21 +1401,21 @@ func HandleOpenAIResponsesRequest( } // Large payload passthrough: stream body directly without JSON marshaling - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ResponsesRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostResponsesResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency} + response.ExtraFields = 
schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1466,8 +1425,7 @@ func HandleOpenAIResponsesRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1489,9 +1447,9 @@ func HandleOpenAIResponsesRequest( providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -1502,7 +1460,7 @@ func HandleOpenAIResponsesRequest( if lpResult != nil { return &schemas.BifrostResponsesResponse{ Model: request.Model, - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ResponsesRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -1520,9 
+1478,6 @@ func HandleOpenAIResponsesRequest( return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse) } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ResponsesRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -1619,8 +1574,7 @@ func HandleOpenAIResponsesStreaming( } } return reqBody, nil - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -1664,9 +1618,9 @@ func HandleOpenAIResponsesStreaming( }, jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Store provider response headers in context before status check so error responses also forward them @@ -1677,9 +1631,9 @@ func HandleOpenAIResponsesStreaming( defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) if customErrorConverter != nil { - return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, 
nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ResponsesStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -1694,11 +1648,12 @@ func HandleOpenAIResponsesStreaming( // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ResponsesStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -1720,7 +1675,7 @@ func HandleOpenAIResponsesStreaming( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). 
if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -1742,7 +1697,7 @@ func HandleOpenAIResponsesStreaming( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -1754,11 +1709,6 @@ func HandleOpenAIResponsesStreaming( if customResponseHandler != nil { rawRequest, rawResponse, bifrostErr := customResponseHandler([]byte(jsonData), &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesStreamRequest, - } if sendBackRawRequest { bifrostErr.ExtraFields.RawRequest = rawRequest } @@ -1792,11 +1742,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Message != nil { @@ -1829,11 +1774,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: 
schemas.ResponsesStreamRequest,
-						Provider:       providerName,
-						ModelRequested: request.Model,
-					},
 				}
 				if response.Response != nil && response.Response.Error != nil {
 					bifrostErr.Error.Message = response.Response.Error.Message
@@ -1844,11 +1784,7 @@ func HandleOpenAIResponsesStreaming(
 				return
 			}
-			response.ExtraFields.RequestType = schemas.ResponsesStreamRequest
-			response.ExtraFields.Provider = providerName
-			response.ExtraFields.ModelRequested = request.Model
 			response.ExtraFields.ChunkIndex = response.SequenceNumber
-
 			if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete {
 				// Set raw request if enabled
 				if sendBackRawRequest {
@@ -1866,7 +1802,6 @@ func HandleOpenAIResponsesStreaming(
 				providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan)
 			}
 		}
-
 	}()

 	return responseChan, nil
@@ -1937,22 +1872,22 @@ func HandleOpenAIEmbeddingRequest(
 	}

 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.EmbeddingRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostEmbeddingResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostEmbeddingResponse{
 			Model: request.Model,
 			Usage: lpResult.Usage,
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -1962,8 +1897,7 @@
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToOpenAIEmbeddingRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1984,7 +1918,7 @@
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.EmbeddingRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -1996,7 +1930,7 @@
 		return &schemas.BifrostEmbeddingResponse{
 			Model: request.Model,
 			Usage: lpResult.Usage,
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.EmbeddingRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -2014,9 +1948,6 @@
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
 	}

-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.EmbeddingRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -2095,22 +2026,21 @@ func HandleOpenAISpeechRequest(
 	}

 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.SpeechRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		// Speech response is raw audio bytes (MP3/WAV), not JSON
 		return &schemas.BifrostSpeechResponse{
 			Audio: lpResult.ResponseBody,
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}

 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
-		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil },
-		providerName)
+		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil })
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2131,7 +2061,7 @@
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Get the binary audio data from the response body
@@ -2142,7 +2072,7 @@
 	}
 	if lpResult != nil {
 		return &schemas.BifrostSpeechResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.SpeechRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -2152,9 +2082,6 @@
 	bifrostResponse := &schemas.BifrostSpeechResponse{
 		Audio: body,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.SpeechRequest,
-			Provider: providerName,
-			ModelRequested: request.Model,
 			Latency: latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
@@ -2177,7 +2104,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHo
 	for _, model := range providerUtils.UnsupportedSpeechStreamModels {
 		if model == request.Model {
-			return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil, provider.GetProviderKey())
+			return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil)
 		}
 	}
@@ -2262,8 +2189,7 @@
 			}
 		}
 		return reqBody, nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2289,9 +2215,9 @@
 		}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}
-	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Store provider response headers in context before status check so error responses also forward them
@@ -2301,7 +2227,7 @@
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.SpeechStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -2316,11 +2242,12 @@
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.SpeechStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -2342,7 +2269,7 @@ func HandleOpenAISpeechStreamRequest(
 			// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
 			if providerUtils.DrainNonSSEStreamResponse(resp) {
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger)
 				return
 			}
@@ -2366,7 +2293,7 @@
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				logger.Warn("Error reading stream: %v", readErr)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 			}
 			break
 		}
@@ -2378,11 +2305,6 @@
 			var bifrostErr schemas.BifrostError
 			if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
 				if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider: providerName,
-						ModelRequested: request.Model,
-						RequestType: schemas.SpeechStreamRequest,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
 					return
@@ -2408,11 +2330,8 @@
 			chunkIndex++

 			response.ExtraFields = schemas.BifrostResponseExtraFields{
-				RequestType: schemas.SpeechStreamRequest,
-				Provider: providerName,
-				ModelRequested: request.Model,
-				ChunkIndex: chunkIndex,
-				Latency: time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency: time.Since(lastChunkTime).Milliseconds(),
 			}

 			lastChunkTime = time.Now()
@@ -2433,7 +2352,6 @@ func HandleOpenAISpeechStreamRequest(
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan)
 		}
-
 	}()

 	return responseChan, nil
@@ -2474,7 +2392,7 @@ func HandleOpenAITranscriptionRequest(
 	logger schemas.Logger,
 ) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
 	// Large payload passthrough: stream multipart body directly without parsing
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.TranscriptionRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
@@ -2482,13 +2400,13 @@
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostTranscriptionResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostTranscriptionResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -2517,7 +2435,7 @@
 	// Use centralized converter
 	reqBody := ToOpenAITranscriptionRequest(request)
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
 	}

 	// Create multipart form
@@ -2544,7 +2462,7 @@
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.TranscriptionRequest, providerName, request.Model)
+		return nil, ParseOpenAIError(resp)
 	}

 	responseBody, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -2554,7 +2472,7 @@
 	}
 	if lpResult != nil {
 		return &schemas.BifrostTranscriptionResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.TranscriptionRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -2574,7 +2492,12 @@
 	// Parse OpenAI's transcription response directly into BifrostTranscribe
 	response := &schemas.BifrostTranscriptionResponse{}
 	var rawResponse interface{}
-	if customResponseHandler != nil {
+	if request.Params != nil && schemas.IsPlainTextTranscriptionFormat(request.Params.ResponseFormat) {
+		response.Text = string(copiedResponseBody)
+		if sendBackRawResponse {
+			rawResponse = string(copiedResponseBody)
+		}
+	} else if customResponseHandler != nil {
 		_, rawResponse, bifrostErr = customResponseHandler(copiedResponseBody, response, nil, false, sendBackRawResponse)
 	} else {
 		if err := sonic.Unmarshal(copiedResponseBody, response); err != nil {
@@ -2588,15 +2511,15 @@
 					},
 				}
 			}
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 		}

-		//TODO: add HandleProviderResponse here
+		// TODO: add HandleProviderResponse here

 		// Parse raw response for RawResponse field
 		if sendBackRawResponse {
 			if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err)
 			}
 		}
 	}
@@ -2606,9 +2529,6 @@
 	}

 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.TranscriptionRequest,
-		Provider: providerName,
-		ModelRequested: request.Model,
 		Latency: latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -2670,7 +2590,7 @@ func HandleOpenAITranscriptionStreamRequest(
 	// Use centralized converter
 	reqBody := ToOpenAITranscriptionRequest(request)
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
 	}
 	reqBody.Stream = schemas.Ptr(true)
 	if postRequestConverter != nil {
@@ -2731,9 +2651,9 @@
 		}
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
-	return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName)
+	return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
 	}

 	// Store provider response headers in context before status check so error responses also forward them
@@ -2743,7 +2663,7 @@ func HandleOpenAITranscriptionStreamRequest(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model)
+		return nil, ParseOpenAIError(resp)
 	}

 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -2758,11 +2678,12 @@
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -2784,7 +2705,7 @@ func HandleOpenAITranscriptionStreamRequest(
 			// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
 			if providerUtils.DrainNonSSEStreamResponse(resp) {
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger)
 				return
 			}
@@ -2809,7 +2730,7 @@ func HandleOpenAITranscriptionStreamRequest(
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				logger.Warn("Error reading stream: %v", readErr)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 			}
 			break
 		}
@@ -2820,11 +2741,6 @@
 			if customResponseHandler != nil {
 				_, _, bifrostErr = customResponseHandler([]byte(jsonData), response, nil, false, false)
 				if bifrostErr != nil {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider: providerName,
-						ModelRequested: request.Model,
-						RequestType: schemas.TranscriptionStreamRequest,
-					}
 					if sendBackRawResponse {
 						bifrostErr.ExtraFields.RawResponse = jsonData
 					}
@@ -2839,13 +2755,9 @@
 			var bifrostErrVal schemas.BifrostError
 			if err := sonic.UnmarshalString(jsonData, &bifrostErrVal); err == nil {
 				if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" {
-					bifrostErrVal.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider: providerName,
-						ModelRequested: request.Model,
-						RequestType: schemas.TranscriptionStreamRequest,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, nil, nil, false, sendBackRawResponse), responseChan, logger)
+					respBody := append([]byte(nil), resp.Body()...)
+					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger)
 					return
 				}
 			}
@@ -2869,11 +2781,8 @@
 			chunkIndex++

 			response.ExtraFields = schemas.BifrostResponseExtraFields{
-				RequestType: schemas.TranscriptionStreamRequest,
-				Provider: providerName,
-				ModelRequested: request.Model,
-				ChunkIndex: chunkIndex,
-				Latency: time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency: time.Since(lastChunkTime).Milliseconds(),
 			}

 			lastChunkTime = time.Now()
@@ -2895,7 +2804,6 @@ func HandleOpenAITranscriptionStreamRequest(
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan)
 		}
-
 	}()

 	return responseChan, nil
@@ -2905,8 +2813,8 @@ func HandleOpenAITranscriptionStreamRequest(
 // It formats the request, sends it to OpenAI, and processes the response.
 // Returns a BifrostResponse containing the bifrost response or an error if the request fails.
 func (provider *OpenAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key,
-	req *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-
+	req *schemas.BifrostImageGenerationRequest,
+) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
 	if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil {
 		return nil, err
 	}
@@ -2939,7 +2847,6 @@ func HandleOpenAIImageGenerationRequest(
 	sendBackRawResponse bool,
 	logger schemas.Logger,
 ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-
 	// Create request
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -2965,20 +2872,20 @@
 	}

 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageGenerationRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostImageGenerationResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 			}
-			response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency}
+			response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
 			return response, nil
 		}
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -2988,8 +2895,7 @@
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToOpenAIImageGenerationRequest(request), nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -3010,7 +2916,7 @@
 	if resp.StatusCode() != fasthttp.StatusOK {
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
 		logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
@@ -3020,7 +2926,7 @@
 	}
 	if lpResult != nil {
 		return &schemas.BifrostImageGenerationResponse{
-			ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageGenerationRequest, Latency: lpResult.Latency},
+			ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
 		}, nil
 	}
@@ -3032,9 +2938,6 @@
 		return nil, bifrostErr
 	}

-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
-	response.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
@@ -3060,9 +2963,8 @@ func (provider *OpenAIProvider) ImageGenerationStream(
 	key schemas.Key,
 	request *schemas.BifrostImageGenerationRequest,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}

 	// Check if image generation stream is allowed for this provider
@@ -3109,7 +3011,6 @@ func HandleOpenAIImageGenerationStreaming(
 	postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse,
 	logger schemas.Logger,
 ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-
 	// Set headers
 	headers := map[string]string{
 		"Content-Type": "application/json",
@@ -3137,8 +3038,7 @@
 			}
 		}
 		return reqBody, nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -3183,9 +3083,9 @@
 		}
 	}
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName)
+		return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
 	}
-	return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName)
+	return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
 	}

 	// Store provider response headers in context before status check so error responses also forward them
@@ -3195,7 +3095,7 @@ func HandleOpenAIImageGenerationStreaming(
 	if resp.StatusCode() != fasthttp.StatusOK {
 		defer providerUtils.ReleaseStreamingResponse(resp)
 		providerUtils.MaterializeStreamErrorBody(ctx, resp)
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageGenerationStreamRequest, providerName, request.Model), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
 	}

 	// Large payload streaming passthrough — pipe raw upstream SSE to client
@@ -3210,11 +3110,12 @@
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger)
 			}
 			close(responseChan)
 		}()
@@ -3236,7 +3137,7 @@ func HandleOpenAIImageGenerationStreaming(
 			// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
 			if providerUtils.DrainNonSSEStreamResponse(resp) {
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger)
+				providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger)
 				return
 			}
@@ -3263,7 +3164,7 @@
 			if readErr != nil {
 				if readErr != io.EOF {
 					logger.Warn("Error reading stream: %v", readErr)
-					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, logger)
+					providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger)
 				}
 				break
 			}
@@ -3275,11 +3176,6 @@
 			var bifrostErr schemas.BifrostError
 			if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
 				if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
-					bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-						Provider: providerName,
-						ModelRequested: request.Model,
-						RequestType: schemas.ImageGenerationStreamRequest,
-					}
 					ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 					providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
 					return
@@ -3299,11 +3195,6 @@
 				bifrostErr := &schemas.BifrostError{
 					IsBifrostError: false,
 					Error: &schemas.ErrorField{},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider: providerName,
-						ModelRequested: request.Model,
-						RequestType: schemas.ImageGenerationStreamRequest,
-					},
 				}
 				// Guard access to response.Error fields
 				if response.Error != nil {
@@ -3404,11 +3295,8 @@ func HandleOpenAIImageGenerationStreaming(
 				Background: response.Background,
 				OutputFormat: response.OutputFormat,
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType: schemas.ImageGenerationStreamRequest,
-					Provider: providerName,
-					ModelRequested: request.Model,
-					ChunkIndex: chunkIndex, // Chunk order within this image
-					Latency: time.Since(lastChunkTime).Milliseconds(),
+					ChunkIndex: chunkIndex, // Chunk order within this image
+					Latency: time.Since(lastChunkTime).Milliseconds(),
 				},
 			}
@@ -3471,7 +3359,6 @@ func HandleOpenAIImageGenerationStreaming(
 				return
 			}
 		}
-
 	}()

 	return responseChan, nil
@@ -3515,7 +3402,7 @@ func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}

 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -3544,7 +3431,7 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}

 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -3585,12 +3472,12 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoDownloadRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Get content type from response
@@ -3608,8 +3495,6 @@ func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key s
 		Content: content,
 		ContentType: contentType,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.VideoDownloadRequest,
-			Provider: providerName,
 			Latency: latency.Milliseconds(),
 			ProviderResponseHeaders: providerResponseHeaders,
 		},
@@ -3625,7 +3510,7 @@ func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key sch
 	providerName := provider.GetProviderKey()

 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}

 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -3695,10 +3580,10 @@ func HandleOpenAIVideoGenerationRequest(
 	// Use centralized converter
 	reqBody, err := ToOpenAIVideoGenerationRequest(request)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err)
 	}
 	if reqBody == nil {
-		return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil)
 	}

 	// Create multipart form
@@ -3724,12 +3609,12 @@
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoGenerationRequest, providerName, request.Model)
+		return nil, ParseOpenAIError(resp)
 	}

 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Check for empty response
@@ -3755,9 +3640,6 @@
 	}

 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoGenerationRequest,
-		Provider: providerName,
-		ModelRequested: request.Model,
 		Latency: latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -3822,12 +3704,12 @@ func HandleOpenAIVideoRetrieveRequest(

 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoRetrieveRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	response := &schemas.BifrostVideoGenerationResponse{}
@@ -3869,8 +3751,6 @@
 	}

 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoRetrieveRequest,
-		Provider: providerName,
 		Latency: latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -3922,12 +3802,12 @@
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoDeleteRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	// Parse OpenAI's video response
@@ -3941,8 +3821,6 @@ func HandleOpenAIVideoDeleteRequest(
 	}

 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoDeleteRequest,
-		Provider: providerName,
 		Latency: latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -4015,12 +3893,12 @@ func HandleOpenAIVideoListRequest(
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
 		logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
-		return nil, ParseOpenAIError(resp, schemas.VideoListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}

 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}

 	response := &schemas.BifrostVideoListResponse{}
@@ -4047,8 +3925,6 @@
 	}

 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoListRequest,
-		Provider: providerName,
 		Latency: latency.Milliseconds(),
 		ProviderResponseHeaders: providerResponseHeaders,
 	}
@@ -4118,20 +3994,20 @@ func HandleOpenAICountTokensRequest(
 	}

 	// Large payload passthrough: stream body directly without JSON marshaling
-	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.CountTokensRequest, logger); handled {
+	if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
 		if lpErr != nil {
 			return nil, lpErr
 		}
 		if len(lpResult.ResponseBody) > 0 {
 			response := &schemas.BifrostCountTokensResponse{}
 			if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
-				return
nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4140,9 +4016,7 @@ func HandleOpenAICountTokensRequest( request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIResponsesRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -4163,7 +4037,7 @@ func HandleOpenAICountTokensRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.CountTokensRequest, providerName, request.Model), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4173,7 +4047,7 @@ func HandleOpenAICountTokensRequest( } if lpResult != nil { return &schemas.BifrostCountTokensResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.CountTokensRequest, Latency: lpResult.Latency}, + 
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4186,9 +4060,6 @@ func HandleOpenAICountTokensRequest( } response.Model = request.Model - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4236,26 +4107,26 @@ func HandleOpenAIImageEditRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageEditRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: 
lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageEditRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4302,7 +4173,7 @@ func HandleOpenAIImageEditRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4312,7 +4183,7 @@ func HandleOpenAIImageEditRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageEditRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4321,9 +4192,6 @@ func HandleOpenAIImageEditRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageEditRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4385,10 +4253,9 @@ func HandleOpenAIImageEditStreamRequest( postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse, logger schemas.Logger, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - reqBody := ToOpenAIImageEditRequest(request) if 
reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -4448,9 +4315,9 @@ func HandleOpenAIImageEditStreamRequest( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp)) @@ -4459,7 +4326,7 @@ func HandleOpenAIImageEditStreamRequest( if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageEditStreamRequest, providerName, request.Model), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -4474,11 +4341,12 @@ func HandleOpenAIImageEditStreamRequest( // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + 
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -4500,7 +4368,7 @@ func HandleOpenAIImageEditStreamRequest( // on non-line-delimited data (e.g. provider returned JSON instead of SSE). if providerUtils.DrainNonSSEStreamResponse(resp) { ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) - providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger) return } @@ -4527,7 +4395,7 @@ func HandleOpenAIImageEditStreamRequest( if readErr != nil { if readErr != io.EOF { logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr)) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageEditStreamRequest, providerName, request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -4539,11 +4407,6 @@ func HandleOpenAIImageEditStreamRequest( var bifrostErr schemas.BifrostError if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil { if bifrostErr.Error != nil && bifrostErr.Error.Message != "" { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, 
body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return @@ -4563,11 +4426,6 @@ func HandleOpenAIImageEditStreamRequest( bifrostErr := &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditStreamRequest, - }, } // Guard access to response.Error fields if response.Error != nil { @@ -4668,11 +4526,8 @@ func HandleOpenAIImageEditStreamRequest( Background: response.Background, OutputFormat: response.OutputFormat, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ImageEditStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, // Chunk order within this image - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, // Chunk order within this image + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -4731,7 +4586,6 @@ func HandleOpenAIImageEditStreamRequest( return } } - }() return responseChan, nil @@ -4773,26 +4627,26 @@ func HandleOpenAIImageVariationRequest( logger schemas.Logger, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { // Large payload passthrough: stream multipart body directly without parsing - if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, request.Model, schemas.ImageVariationRequest, logger); handled { + if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled { if lpErr != nil { return nil, lpErr } if len(lpResult.ResponseBody) > 0 { response := &schemas.BifrostImageGenerationResponse{} if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) + return nil, 
providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err) } - response.ExtraFields = schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency} + response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency} return response, nil } return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } openaiReq := ToOpenAIImageVariationRequest(request) if openaiReq == nil { - return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil) } // Create request @@ -4838,7 +4692,7 @@ func HandleOpenAIImageVariationRequest( if resp.StatusCode() != fasthttp.StatusOK { providerUtils.MaterializeStreamErrorBody(ctx, resp) - return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.ImageVariationRequest, providerName, request.Model), bodyData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse) } bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger) @@ -4848,7 +4702,7 @@ func HandleOpenAIImageVariationRequest( } if lpResult != nil { return &schemas.BifrostImageGenerationResponse{ - ExtraFields: schemas.BifrostResponseExtraFields{Provider: providerName, ModelRequested: request.Model, RequestType: schemas.ImageVariationRequest, Latency: lpResult.Latency}, + ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}, }, nil } @@ -4857,9 +4711,6 @@ func 
HandleOpenAIImageVariationRequest( if bifrostErr != nil { return nil, bifrostErr } - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - response.ExtraFields.RequestType = schemas.ImageVariationRequest response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -4876,14 +4727,12 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, err } - providerName := provider.GetProviderKey() - if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } if request.Purpose == "" { - return nil, providerUtils.NewBifrostOperationError("purpose is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("purpose is required", nil) } // Create multipart form data @@ -4892,16 +4741,16 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Add purpose field if err := writer.WriteField("purpose", string(request.Purpose)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err) } // Add expires_after fields if provided if request.ExpiresAfter != nil { if err := writer.WriteField("expires_after[anchor]", request.ExpiresAfter.Anchor); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err) } if err := writer.WriteField("expires_after[seconds]", fmt.Sprintf("%d", request.ExpiresAfter.Seconds)); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write 
expires_after[seconds] field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err) } } @@ -4912,14 +4761,14 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche } part, err := writer.CreateFormFile("file", filename) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -4949,13 +4798,13 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileUploadRequest, providerName, "") + provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body())) + return nil, ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var openAIResp OpenAIFileResponse @@ -4966,7 +4815,7 @@ func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key sche return nil, bifrostErr } - fileResponse := 
openAIResp.ToBifrostFileUploadResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) + fileResponse := openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse) fileResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) return fileResponse, nil } @@ -4985,7 +4834,7 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Initialize serial pagination helper helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } // Get current key to query @@ -4996,10 +4845,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -5049,12 +4894,12 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.FileListRequest, providerName, "") + return nil, ParseOpenAIError(resp) } body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var openAIResp OpenAIFileListResponse @@ -5090,8 +4935,6 @@ func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []sch Data: files, HasMore: hasMore, 
ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -5112,7 +4955,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5147,7 +4990,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileRetrieveRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5157,7 +5000,7 @@ func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys [ if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5188,7 +5031,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -5223,7 +5066,7 @@ func (provider 
*OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileDeleteRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5233,7 +5076,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5254,9 +5097,7 @@ func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []s Object: openAIResp.Object, Deleted: openAIResp.Deleted, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, } @@ -5283,7 +5124,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } var lastErr *schemas.BifrostError @@ -5314,7 +5155,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - lastErr = ParseOpenAIError(resp, schemas.FileContentRequest, providerName, "") + lastErr = ParseOpenAIError(resp) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) continue @@ -5324,7 +5165,7 @@ func 
(provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -5343,9 +5184,7 @@ func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys [] Content: content, ContentType: contentType, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileContentRequest, - Provider: providerName, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), }, }, nil } @@ -5362,10 +5201,10 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } if request.Input == nil || request.Input.Prompt == "" { - return nil, providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("prompt is required", nil) } jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( @@ -5373,8 +5212,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIVideoRemixRequest(request) - }, - providerName) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -5412,12 +5250,12 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body())) - return nil, ParseOpenAIError(resp, schemas.VideoRemixRequest, providerName, "") + return nil, 
ParseOpenAIError(resp)
 	}
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	// Parse OpenAI's video response
@@ -5435,9 +5273,7 @@ func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key sche
 	}
 	response.ExtraFields = schemas.BifrostResponseExtraFields{
-		RequestType: schemas.VideoRemixRequest,
-		Provider:    providerName,
-		Latency:     latency.Milliseconds(),
+		Latency: latency.Milliseconds(),
 	}
 	if sendBackRawResponse {
@@ -5456,8 +5292,6 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
-
 	inputFileID := request.InputFileID
 	// If no file_id provided but inline requests are available, upload them first
@@ -5465,7 +5299,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		// Convert inline requests to JSONL format
 		jsonlData, err := ConvertRequestsToJSONL(request.Requests)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err, providerName)
+			return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err)
 		}
 		// Upload the file with purpose "batch"
@@ -5484,12 +5318,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	// Validate that we have a file ID (either provided or uploaded)
 	if inputFileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil)
 	}
 	// Validate that we have an endpoint
 	if request.Endpoint == "" {
-		return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil)
 	}
 	// Create request
@@ -5524,7 +5358,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	jsonData, err := providerUtils.MarshalSorted(openAIReq)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 	req.SetBody(jsonData)
@@ -5540,12 +5374,12 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp, schemas.BatchCreateRequest, providerName, ""), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 	}
 	var openAIResp OpenAIBatchResponse
@@ -5554,7 +5388,7 @@ func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key sch
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
 	}
-	return openAIResp.ToBifrostBatchCreateResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
+	return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
 }
 // BatchList lists batch jobs using serial pagination across keys.
@@ -5564,14 +5398,13 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
 	// Initialize serial pagination helper
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 	// Get current key to query
@@ -5582,10 +5415,6 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 			Object:  "list",
 			Data:    []schemas.BifrostBatchRetrieveResponse{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType: schemas.BatchListRequest,
-				Provider:    providerName,
-			},
 		}, nil
 	}
@@ -5629,12 +5458,12 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, ParseOpenAIError(resp, schemas.BatchListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 	body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
 	if decodeErr != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
 	}
 	var openAIResp OpenAIBatchListResponse
@@ -5647,7 +5476,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 	batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data))
 	var lastBatchID string
 	for _, batch := range openAIResp.Data {
-		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse))
+		batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse))
 		lastBatchID = batch.ID
 	}
@@ -5661,9 +5490,7 @@ func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []sc
 		Data:    batches,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchListRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
 	if nextCursor != "" {
@@ -5680,10 +5507,9 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 	}
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, request.Provider)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -5715,7 +5541,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.BatchRetrieveRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5725,7 +5551,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5741,8 +5567,7 @@ func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
 		fasthttp.ReleaseRequest(req)
 		fasthttp.ReleaseResponse(resp)
-		result := openAIResp.ToBifrostBatchRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
-		result.ExtraFields.RequestType = schemas.BatchRetrieveRequest
+		result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
 		return result, nil
 	}
@@ -5756,10 +5581,9 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 	}
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -5791,7 +5615,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.BatchCancelRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5801,7 +5625,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5824,9 +5648,7 @@ func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []
 		CancellingAt: openAIResp.CancellingAt,
 		CancelledAt:  openAIResp.CancelledAt,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchCancelRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -5866,11 +5688,9 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [
 	}
 	if request.BatchID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil, schemas.OpenAI)
+		return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
 	}
-	providerName := provider.GetProviderKey()
-
 	// First, retrieve the batch to get the output_file_id (this already iterates over keys)
 	batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{
 		Provider: request.Provider,
@@ -5881,7 +5701,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [
 	}
 	if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil)
 	}
 	// Download the output file - try each key
@@ -5911,7 +5731,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.BatchResultsRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -5921,7 +5741,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
@@ -5945,9 +5765,7 @@ func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys [
 		BatchID: request.BatchID,
 		Results: results,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.BatchResultsRequest,
-			Provider:    providerName,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -5967,14 +5785,12 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.Name == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil)
 	}
 	// Build request body
@@ -6010,7 +5826,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key
 	jsonBody, err := providerUtils.MarshalSorted(reqBody)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 	}
 	// Create request
@@ -6039,7 +5855,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
-		return nil, ParseOpenAIError(resp, schemas.ContainerCreateRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 	// Parse response
@@ -6073,9 +5889,7 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key
 		MemoryLimit: containerResp.MemoryLimit,
 		Metadata:    containerResp.Metadata,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerCreateRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6092,16 +5906,14 @@ func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key
 // ContainerList lists containers via OpenAI's API.
 // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next.
 func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
 		}
 	}
@@ -6115,7 +5927,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys
 	// Initialize serial pagination helper for multi-key support
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 	// Get current key to query
@@ -6126,10 +5938,6 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys
 			Object:  "list",
 			Data:    []schemas.ContainerObject{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:    providerName,
-				RequestType: schemas.ContainerListRequest,
-			},
 		}, nil
 	}
@@ -6176,7 +5984,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys
 	// Handle error response
 	if resp.StatusCode() != fasthttp.StatusOK {
-		return nil, ParseOpenAIError(resp, schemas.ContainerListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 	// Parse response
@@ -6211,9 +6019,7 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys
 		LastID:  listResp.LastID,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerListRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6234,20 +6040,18 @@ func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys
 // ContainerRetrieve retrieves a specific container via OpenAI's API.
 func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
 		}
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("container_id is required", nil)
 	}
 	if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerRetrieveRequest); err != nil {
@@ -6282,7 +6086,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.ContainerRetrieveRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6322,9 +6126,7 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k
 		MemoryLimit: containerResp.MemoryLimit,
 		Metadata:    containerResp.Metadata,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerRetrieveRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6345,20 +6147,18 @@ func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, k
 // ContainerDelete deletes a container via OpenAI's API.
 func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
 		}
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("container_id is required", nil)
 	}
 	if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerDeleteRequest); err != nil {
@@ -6393,7 +6193,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key
 		// Handle error response
 		if resp.StatusCode() != fasthttp.StatusOK {
-			lastErr = ParseOpenAIError(resp, schemas.ContainerDeleteRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6421,9 +6221,7 @@ func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, key
 		Object:  deleteResp.Object,
 		Deleted: deleteResp.Deleted,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerDeleteRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6452,14 +6250,12 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
 	}
 	// Create request
@@ -6476,7 +6272,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 	// Handle file upload (multipart only)
 	if len(request.File) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil)
 	}
 	// Multipart file upload
@@ -6486,13 +6282,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 	// Add file
 	part, err := writer.CreateFormFile("file", "file")
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err)
 	}
 	if _, err = part.Write(request.File); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err)
 	}
 	if err := writer.Close(); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err)
 	}
 	req.Header.Set("Content-Type", writer.FormDataContentType())
 	req.SetBody(body.Bytes())
@@ -6510,13 +6306,13 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 	// Handle error response
 	if resp.StatusCode() >= 400 {
-		return nil, ParseOpenAIError(resp, schemas.ContainerFileCreateRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 	// Decode response body (handles content-encoding like gzip)
 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -6545,9 +6341,7 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 		Path:   fileResp.Path,
 		Source: fileResp.Source,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerFileCreateRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6565,21 +6359,19 @@ func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext,
 // ContainerFileList lists files in a container via OpenAI's API.
 // Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next.
 func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
 	}
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
 		}
 	}
@@ -6593,7 +6385,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k
 	// Initialize serial pagination helper for multi-key support
 	helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
 	}
 	// Get current key to query
@@ -6604,10 +6396,6 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k
 			Object:  "list",
 			Data:    []schemas.ContainerFileObject{},
 			HasMore: false,
-			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:    providerName,
-				RequestType: schemas.ContainerFileListRequest,
-			},
 		}, nil
 	}
@@ -6653,13 +6441,13 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k
 	}
 	if resp.StatusCode() >= 400 {
-		return nil, ParseOpenAIError(resp, schemas.ContainerFileListRequest, providerName, "")
+		return nil, ParseOpenAIError(resp)
 	}
 	// Decode response body (handles content-encoding like gzip)
 	responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 	var listResp struct {
@@ -6691,9 +6479,7 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k
 		LastID:  listResp.LastID,
 		HasMore: hasMore,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerFileListRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6714,13 +6500,11 @@ func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, k
 // ContainerFileRetrieve retrieves a file from a container via OpenAI's API.
 func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
 		}
 	}
@@ -6729,15 +6513,15 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex
 	}
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
 	}
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
 	}
 	var lastErr *schemas.BifrostError
@@ -6765,7 +6549,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex
 		}
 		if resp.StatusCode() >= 400 {
-			lastErr = ParseOpenAIError(resp, schemas.ContainerFileRetrieveRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6774,7 +6558,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex
 		// Decode response body (handles content-encoding like gzip)
 		responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 		if err != nil {
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6809,9 +6593,7 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex
 		Path:   fileResp.Path,
 		Source: fileResp.Source,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerFileRetrieveRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6832,13 +6614,11 @@ func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContex
 // ContainerFileContent retrieves the content of a file from a container via OpenAI's API.
 func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
 		}
 	}
@@ -6847,15 +6627,15 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext
 	}
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
 	}
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
 	}
 	var lastErr *schemas.BifrostError
@@ -6883,7 +6663,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext
 		}
 		if resp.StatusCode() >= 400 {
-			lastErr = ParseOpenAIError(resp, schemas.ContainerFileContentRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6900,7 +6680,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext
 		if err != nil {
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			continue
 		}
 		content := append([]byte(nil), body...)
@@ -6909,9 +6689,7 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext
 		Content:     content,
 		ContentType: contentType,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerFileContentRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -6935,13 +6713,11 @@ func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext
 // ContainerFileDelete deletes a file from a container via OpenAI's API.
 func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
 	if len(keys) == 0 {
 		if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
 			keys = []schemas.Key{{}}
 		} else {
-			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
 		}
 	}
@@ -6950,15 +6726,15 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext,
 	}
 	if request == nil {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
 	}
 	if request.ContainerID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
 	}
 	if request.FileID == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
 	}
 	var lastErr *schemas.BifrostError
@@ -6986,7 +6762,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext,
 		}
 		if resp.StatusCode() >= 400 {
-			lastErr = ParseOpenAIError(resp, schemas.ContainerFileDeleteRequest, providerName, "")
+			lastErr = ParseOpenAIError(resp)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -6995,7 +6771,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext,
 		// Decode response body (handles content-encoding like gzip)
 		responseBody, err := providerUtils.CheckAndDecodeBody(resp)
 		if err != nil {
-			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+			lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			continue
@@ -7022,9 +6798,7 @@ func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext,
 		Object:  deleteResp.Object,
 		Deleted: deleteResp.Deleted,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			Provider:    providerName,
-			RequestType: schemas.ContainerFileDeleteRequest,
-			Latency:     latency.Milliseconds(),
+			Latency: latency.Milliseconds(),
 		},
 	}
@@ -7093,7 +6867,7 @@ func (provider *OpenAIProvider) Passthrough(
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
 	}
 	// Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body.
@@ -7109,9 +6883,6 @@ func (provider *OpenAIProvider) Passthrough( Body: body, } - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = req.Model - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -7179,9 +6950,9 @@ func (provider *OpenAIProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } headers := make(map[string]string) @@ -7195,9 +6966,7 @@ func (provider *OpenAIProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Wrap reader with idle timeout to detect stalled streams. @@ -7206,11 +6975,7 @@ func (provider *OpenAIProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. 
 	stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)

-	extraFields := schemas.BifrostResponseExtraFields{
-		Provider:       provider.GetProviderKey(),
-		ModelRequested: req.Model,
-		RequestType:    schemas.PassthroughStreamRequest,
-	}
+	extraFields := schemas.BifrostResponseExtraFields{}
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
 	}
@@ -7218,11 +6983,12 @@
 	ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)

 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger)
 			}
 			close(ch)
 		}()
@@ -7271,7 +7037,7 @@
 			}
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			extraFields.Latency = time.Since(startTime).Milliseconds()
-			providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger)
+			providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger)
 			return
 		}
 	}
diff --git a/core/providers/openai/openai_test.go b/core/providers/openai/openai_test.go
index c37040ce62..d2173e1ac7 100644
--- a/core/providers/openai/openai_test.go
+++ b/core/providers/openai/openai_test.go
@@ -46,69 +46,69 @@ func TestOpenAI(t *testing.T) {
 		ChatAudioModel:   "gpt-4o-mini-audio-preview",
 		PassthroughModel: "gpt-4o",
 		Scenarios: llmtests.TestScenarios{
-			TextCompletion:        true,
-			TextCompletionStream:  true,
-			SimpleChat:            true,
-			CompletionStream:      true,
-			MultiTurnConversation: true,
-			ToolCalls:             true,
-			ToolCallsStreaming:    true,
+			TextCompletion:             true,
+			TextCompletionStream:       true,
+			SimpleChat:                 true,
+			CompletionStream:           true,
+			MultiTurnConversation:      true,
+			ToolCalls:                  true,
+			ToolCallsStreaming:         true,
 			MultipleToolCalls:          true,
 			MultipleToolCallsStreaming: true,
-			End2EndToolCalling:    true,
-			AutomaticFunctionCall: true,
-			WebSearchTool:         true,
-			ImageURL:              true,
-			ImageBase64:           true,
-			MultipleImages:        true,
-			FileBase64:            true,
-			FileURL:               true,
-			CompleteEnd2End:       true,
-			SpeechSynthesis:       true,
-			SpeechSynthesisStream: true,
-			Transcription:         true,
-			TranscriptionStream:   true,
-			Embedding:             true,
-			Reasoning:             true,
-			ListModels:            true,
-			ImageGeneration:       true,
-			ImageGenerationStream: true,
-			ImageEdit:             true,
-			ImageEditStream:       true,
-			ImageVariation:        false, // dall-e-2 is deprecated and no other OpenAI model supports image variations
-			VideoGeneration:       false, // disabled for now because of long running operations
-			VideoRetrieve:         false,
-			VideoRemix:            false,
-			VideoDownload:         false,
-			VideoList:             false,
-			VideoDelete:           false,
-			BatchCreate:           true,
-			BatchList:             true,
-			BatchRetrieve:         true,
-			BatchCancel:           true,
-			BatchResults:          true,
-			FileUpload:            true,
-			FileList:              true,
-			FileRetrieve:          true,
-			FileDelete:            true,
-			FileContent:           true,
-			FileBatchInput:        true,
-			CountTokens:           true,
-			ChatAudio:             true,
-			StructuredOutputs:     true, // Structured outputs with nullable enum support
-			ContainerCreate:       true,
-			ContainerList:         true,
-			ContainerRetrieve:     true,
-			ContainerDelete:       true,
-			ContainerFileCreate:   true,
-			ContainerFileList:     true,
-			ContainerFileRetrieve: true,
-			ContainerFileContent:  true,
-			ContainerFileDelete:   true,
-			PromptCaching:         true,
-			PassthroughAPI:        true,
-			WebSocketResponses:    true,
-			Realtime:              false,
+			End2EndToolCalling:         true,
+			AutomaticFunctionCall:      true,
+			WebSearchTool:              true,
+			ImageURL:                   true,
+			ImageBase64:                true,
+			MultipleImages:             true,
+			FileBase64:                 true,
+			FileURL:                    true,
+			CompleteEnd2End:            true,
+			SpeechSynthesis:            true,
+			SpeechSynthesisStream:      true,
+			Transcription:              true,
+			TranscriptionStream:        true,
+			Embedding:                  true,
+			Reasoning:                  true,
+			ListModels:                 true,
+			ImageGeneration:            true,
+			ImageGenerationStream:      true,
+			ImageEdit:                  true,
+			ImageEditStream:            true,
+			ImageVariation:             false, // dall-e-2 is deprecated and no other OpenAI model supports image variations
+			VideoGeneration:            false, // disabled for now because of long running operations
+			VideoRetrieve:              false,
+			VideoRemix:                 false,
+			VideoDownload:              false,
+			VideoList:                  false,
+			VideoDelete:                false,
+			BatchCreate:                true,
+			BatchList:                  true,
+			BatchRetrieve:              true,
+			BatchCancel:                true,
+			BatchResults:               true,
+			FileUpload:                 true,
+			FileList:                   true,
+			FileRetrieve:               true,
+			FileDelete:                 true,
+			FileContent:                true,
+			FileBatchInput:             true,
+			CountTokens:                true,
+			ChatAudio:                  true,
+			StructuredOutputs:          true, // Structured outputs with nullable enum support
+			ContainerCreate:            true,
+			ContainerList:              true,
+			ContainerRetrieve:          true,
+			ContainerDelete:            true,
+			ContainerFileCreate:        true,
+			ContainerFileList:          true,
+			ContainerFileRetrieve:      true,
+			ContainerFileContent:       true,
+			ContainerFileDelete:        true,
+			PromptCaching:              true,
+			PassthroughAPI:             true,
+			WebSocketResponses:         true,
+			Realtime:                   false,
 		},
 		RealtimeModel: "gpt-4o-realtime-preview",
 	}
diff --git a/core/providers/openai/realtime.go b/core/providers/openai/realtime.go
index b73db4ea24..8c88382297 100644
--- a/core/providers/openai/realtime.go
+++ b/core/providers/openai/realtime.go
@@ -1,13 +1,17 @@
 package openai

 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
+	"mime/multipart"
+	"net/http"
 	"net/url"
 	"strings"

 	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
+	"github.com/valyala/fasthttp"
 )

 // SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API.
@@ -28,7 +32,6 @@ func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model stri
 func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string {
 	headers := map[string]string{
 		"Authorization": "Bearer " + key.Value.GetValue(),
-		"OpenAI-Beta":   "realtime=v1",
 	}
 	for k, v := range provider.networkConfig.ExtraHeaders {
 		headers[k] = v
@@ -36,6 +39,380 @@ func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]stri
 	return headers
 }

+// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange.
+func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool {
+	return true
+}
+
+// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls.
+func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP(
+	ctx *schemas.BifrostContext,
+	key schemas.Key,
+	model string,
+	sdp string,
+	session json.RawMessage,
+) (string, *schemas.BifrostError) {
+	path := "/v1/realtime/calls"
+	if session == nil && strings.TrimSpace(model) != "" {
+		path += "?model=" + url.QueryEscape(model)
+	}
+	return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session)
+}
+
+// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime.
+// Same multipart format but targets the legacy endpoint with model in the URL.
+func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP(
+	ctx *schemas.BifrostContext,
+	key schemas.Key,
+	sdp string,
+	session json.RawMessage,
+	model string,
+) (string, *schemas.BifrostError) {
+	return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session)
+}
+
+// exchangeWebRTCSDP is the shared multipart SDP exchange implementation.
+// Builds a multipart body with sdp + optional session, POSTs to the given path.
+func (provider *OpenAIProvider) exchangeWebRTCSDP(
+	ctx *schemas.BifrostContext,
+	key schemas.Key,
+	path string,
+	sdp string,
+	session json.RawMessage,
+) (string, *schemas.BifrostError) {
+	bodyBuf := &bytes.Buffer{}
+	writer := multipart.NewWriter(bodyBuf)
+	if err := writer.WriteField("sdp", sdp); err != nil {
+		return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err)
+	}
+	if session != nil {
+		if err := writer.WriteField("session", string(session)); err != nil {
+			return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err)
+		}
+	}
+	if err := writer.Close(); err != nil {
+		return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err)
+	}
+
+	req := fasthttp.AcquireRequest()
+	resp := fasthttp.AcquireResponse()
+	defer fasthttp.ReleaseRequest(req)
+	defer fasthttp.ReleaseResponse(resp)
+
+	req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest))
+	req.Header.SetMethod(http.MethodPost)
+	req.Header.SetContentType(writer.FormDataContentType())
+	req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
+	for k, v := range provider.networkConfig.ExtraHeaders {
+		req.Header.Set(k, v)
+	}
+	if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil {
+		if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" {
+			req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK)
+		}
+	}
+	req.SetBody(bodyBuf.Bytes())
+
+	_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
+	defer wait()
+	if bifrostErr != nil {
+		return "", bifrostErr
+	}
+
+	answerBody := resp.Body()
+	if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
+		return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), answerBody)
+	}
+
+	return string(answerBody), nil
+}
+
+func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError {
+	bifrostErr := &schemas.BifrostError{
+		IsBifrostError: false,
+		StatusCode:     schemas.Ptr(fasthttp.StatusBadGateway),
+		Error: &schemas.ErrorField{
+			Type:    schemas.Ptr("upstream_connection_error"),
+			Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()),
+		},
+		ExtraFields: schemas.BifrostErrorExtraFields{
+			RequestType: schemas.RealtimeRequest,
+			Provider:    provider.GetProviderKey(),
+		},
+	}
+	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
+		bifrostErr.ExtraFields.RawResponse = map[string]any{
+			"status": statusCode,
+			"body":   string(body),
+		}
+	}
+	return bifrostErr
+}
+
+func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError {
+	bifrostErr := &schemas.BifrostError{
+		IsBifrostError: true,
+		StatusCode:     schemas.Ptr(status),
+		Error: &schemas.ErrorField{
+			Type:    schemas.Ptr(errorType),
+			Message: message,
+		},
+	}
+	if err != nil {
+		bifrostErr.Error.Error = err
+	}
+	return bifrostErr
+}
+
+func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
+	if event == nil {
+		return false
+	}
+	switch event.Type {
+	case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted:
+		return true
+	default:
+		return false
+	}
+}
+
+func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
+	return schemas.RTEventResponseDone
+}
+
+func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string {
+	return "oai-events"
+}
+
+func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string {
+	return "realtime"
+}
+
+func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
+	return true
+}
+
+func (provider *OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
+	switch eventType {
+	case schemas.RTEventResponseTextDelta,
+		schemas.RTEventResponseAudioTransDelta,
+		schemas.RealtimeEventType("response.output_text.delta"),
+		schemas.RealtimeEventType("response.output_audio_transcript.delta"):
+		return true
+	default:
+		return false
+	}
+}
+
+// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns
+// the native OpenAI response body unchanged.
+func (provider *OpenAIProvider) CreateRealtimeClientSecret(
+	ctx *schemas.BifrostContext,
+	key schemas.Key,
+	endpointType schemas.RealtimeSessionEndpointType,
+	rawRequest json.RawMessage,
+) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
+	if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil {
+		return nil, err
+	}
+
+	normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType)
+	if bifrostErr != nil {
+		return nil, bifrostErr
+	}
+	req := fasthttp.AcquireRequest()
+	resp := fasthttp.AcquireResponse()
+	defer fasthttp.ReleaseRequest(req)
+	defer fasthttp.ReleaseResponse(resp)
+
+	req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest))
+	req.Header.SetMethod(http.MethodPost)
+	req.Header.SetContentType("application/json")
+	for k, v := range provider.realtimeSessionHeaders(key, endpointType) {
+		req.Header.Set(k, v)
+	}
+	req.SetBody(normalizedBody)
+
+	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
+	defer wait()
+	if bifrostErr != nil {
+		return nil, bifrostErr
+	}
+
+	headers := providerUtils.ExtractProviderResponseHeaders(resp)
+	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
+
+	if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
+		return nil, ParseOpenAIError(resp)
+	}
+
+	body, err := providerUtils.CheckAndDecodeBody(resp)
+	if err != nil {
+		return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
+	}
+	for k := range headers {
+		if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
+			delete(headers, k)
+		}
+	}
+
+	out := &schemas.BifrostPassthroughResponse{
+		StatusCode: resp.StatusCode(),
+		Headers:    headers,
+		Body:       body,
+	}
+	out.ExtraFields.Provider = provider.GetProviderKey()
+	out.ExtraFields.OriginalModelRequested = requestedModel
+	out.ExtraFields.RequestType = schemas.RealtimeRequest
+	out.ExtraFields.Latency = latency.Milliseconds()
+	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
+		providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields)
+	}
+
+	return out, nil
+}
+
+func normalizeRealtimeClientSecretRequest(
+	rawRequest json.RawMessage,
+	defaultProvider schemas.ModelProvider,
+	endpointType schemas.RealtimeSessionEndpointType,
+) ([]byte, string, *schemas.BifrostError) {
+	root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest)
+	if bifrostErr != nil {
+		return nil, "", bifrostErr
+	}
+
+	modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root)
+	if bifrostErr != nil {
+		return nil, "", bifrostErr
+	}
+	providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider)
+	if normalizedModel == "" {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil)
+	}
+	if providerKey == "" {
+		providerKey = defaultProvider
+	}
+	if providerKey == "" {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil)
+	}
+
+	if endpointType == schemas.RealtimeSessionEndpointSessions {
+		return normalizeRealtimeSessionsRequest(root, normalizedModel)
+	}
+
+	return normalizeRealtimeClientSecretsRequest(root, normalizedModel)
+}
+
+func normalizeRealtimeClientSecretsRequest(
+	root map[string]json.RawMessage,
+	normalizedModel string,
+) ([]byte, string, *schemas.BifrostError) {
+	session := map[string]json.RawMessage{}
+	if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) {
+		if err := json.Unmarshal(existingSession, &session); err != nil {
+			return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
+		}
+	}
+
+	modelJSON, marshalErr := json.Marshal(normalizedModel)
+	if marshalErr != nil {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
+	}
+	session["model"] = modelJSON
+	if _, ok := session["type"]; !ok {
+		typeJSON, marshalErr := json.Marshal("realtime")
+		if marshalErr != nil {
+			return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr)
+		}
+		session["type"] = typeJSON
+	}
+	delete(root, "model")
+
+	sessionJSON, marshalErr := json.Marshal(session)
+	if marshalErr != nil {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr)
+	}
+	root["session"] = sessionJSON
+
+	normalizedBody, marshalErr := json.Marshal(root)
+	if marshalErr != nil {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr)
+	}
+
+	return normalizedBody, normalizedModel, nil
+}
+
+func normalizeRealtimeSessionsRequest(
+	root map[string]json.RawMessage,
+	normalizedModel string,
+) ([]byte, string, *schemas.BifrostError) {
+	if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) {
+		session := map[string]json.RawMessage{}
+		if err := json.Unmarshal(existingSession, &session); err != nil {
+			return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
+		}
+		for key, value := range session {
+			if _, exists := root[key]; !exists {
+				root[key] = value
+			}
+		}
+	}
+
+	modelJSON, marshalErr := json.Marshal(normalizedModel)
+	if marshalErr != nil {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
+	}
+	root["model"] = modelJSON
+	delete(root, "session")
+
+	normalizedBody, marshalErr := json.Marshal(root)
+	if marshalErr != nil {
+		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr)
+	}
+
+	return normalizedBody, normalizedModel, nil
+}
+
+func (provider *OpenAIProvider) realtimeSessionHeaders(
+	key schemas.Key,
+	endpointType schemas.RealtimeSessionEndpointType,
+) map[string]string {
+	headers := map[string]string{
+		"Authorization": "Bearer " + key.Value.GetValue(),
+	}
+	if endpointType == schemas.RealtimeSessionEndpointSessions {
+		headers["OpenAI-Beta"] = "realtime=v1"
+	}
+	for k, v := range provider.networkConfig.ExtraHeaders {
+		headers[k] = v
+	}
+	return headers
+}
+
+func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string {
+	if endpointType == schemas.RealtimeSessionEndpointSessions {
+		return "/v1/realtime/sessions"
+	}
+	return "/v1/realtime/client_secrets"
+}
+
+func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError {
+	return &schemas.BifrostError{
+		IsBifrostError: false,
+		StatusCode:     schemas.Ptr(status),
+		Error: &schemas.ErrorField{
+			Type:    schemas.Ptr(errorType),
+			Message: message,
+			Error:   err,
+		},
+		ExtraFields: schemas.BifrostErrorExtraFields{
+			RequestType: schemas.RealtimeRequest,
+			Provider:    schemas.OpenAI,
+		},
+	}
+}
+
 // openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event.
 type openAIRealtimeEvent struct {
 	Type string `json:"type"`
@@ -44,15 +421,17 @@ type openAIRealtimeEvent struct {
 	Conversation   json.RawMessage `json:"conversation,omitempty"`
 	Item           json.RawMessage `json:"item,omitempty"`
 	Response       json.RawMessage `json:"response,omitempty"`
+	Part           json.RawMessage `json:"part,omitempty"`
 	Delta          string          `json:"delta,omitempty"`
 	Audio          string          `json:"audio,omitempty"`
 	Transcript     string          `json:"transcript,omitempty"`
 	Text           string          `json:"text,omitempty"`
 	Error          json.RawMessage `json:"error,omitempty"`
 	ItemID         string          `json:"item_id,omitempty"`
-	OutputIndex    int             `json:"output_index,omitempty"`
-	ContentIndex   int             `json:"content_index,omitempty"`
+	OutputIndex    *int            `json:"output_index,omitempty"`
+	ContentIndex   *int            `json:"content_index,omitempty"`
 	ResponseID     string          `json:"response_id,omitempty"`
+	AudioEndMS     *int            `json:"audio_end_ms,omitempty"`
 	PreviousItemID string          `json:"previous_item_id,omitempty"`
 }

@@ -105,6 +484,17 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 		EventID: raw.EventID,
 		RawData: providerEvent,
 	}
+	setRealtimeExtraParam(event, "item_id", raw.ItemID)
+	setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID)
+	setRealtimeExtraParam(event, "output_index", raw.OutputIndex)
+	setRealtimeExtraParam(event, "content_index", raw.ContentIndex)
+	setRealtimeExtraParam(event, "response_id", raw.ResponseID)
+	setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS)
+	setRealtimeExtraParam(event, "transcript", raw.Transcript)
+	setRealtimeExtraParam(event, "text", raw.Text)
+	setRealtimeExtraParam(event, "conversation", raw.Conversation)
+	setRealtimeExtraParam(event, "response", raw.Response)
+	setRealtimeExtraParam(event, "part", raw.Part)

 	switch {
 	case raw.Session != nil:
@@ -123,8 +513,10 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 				OutputAudioType: sess.OutputAudioType,
 				Tools:           sess.Tools,
 			}
+			if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 {
+				event.Session.ExtraParams = extra
+			}
 		}
-
 	case raw.Item != nil:
 		var item openAIRealtimeItem
 		if err := json.Unmarshal(raw.Item, &item); err == nil {
@@ -139,6 +531,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 				Arguments: item.Arguments,
 				Output:    item.Output,
 			}
+			if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 {
+				event.Item.ExtraParams = extra
+			}
 		}

 	case raw.Error != nil:
@@ -150,6 +545,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 				Message: rtErr.Message,
 				Param:   rtErr.Param,
 			}
+			if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 {
+				event.Error.ExtraParams = extra
+			}
 		}
 	}

@@ -159,8 +557,8 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 			Audio:      raw.Audio,
 			Transcript: raw.Transcript,
 			ItemID:     raw.ItemID,
-			OutputIdx:  &raw.OutputIndex,
-			ContentIdx: &raw.ContentIndex,
+			OutputIdx:  raw.OutputIndex,
+			ContentIdx: raw.ContentIndex,
 			ResponseID: raw.ResponseID,
 		}
 		if raw.Delta != "" {
@@ -175,19 +573,19 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes
 // ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON.
 func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
-	if bifrostEvent.RawData != nil {
-		return bifrostEvent.RawData, nil
-	}
-
 	out := map[string]interface{}{
 		"type": string(bifrostEvent.Type),
 	}
 	if bifrostEvent.EventID != "" {
 		out["event_id"] = bifrostEvent.EventID
 	}
+	mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams)

 	if bifrostEvent.Session != nil {
 		sess := map[string]interface{}{}
+		if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate {
+			sess["id"] = bifrostEvent.Session.ID
+		}
 		if bifrostEvent.Session.Model != "" {
 			sess["model"] = bifrostEvent.Session.Model
 		}
@@ -218,6 +616,7 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi
 		if bifrostEvent.Session.Tools != nil {
 			sess["tools"] = bifrostEvent.Session.Tools
 		}
+		mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type)
 		out["session"] = sess
 	}

@@ -231,6 +630,9 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi
 		if bifrostEvent.Item.Role != "" {
 			item["role"] = bifrostEvent.Item.Role
 		}
+		if bifrostEvent.Item.Status != "" {
+			item["status"] = bifrostEvent.Item.Status
+		}
 		if bifrostEvent.Item.Content != nil {
 			item["content"] = bifrostEvent.Item.Content
 		}
@@ -246,9 +648,28 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi
 		if bifrostEvent.Item.Output != "" {
 			item["output"] = bifrostEvent.Item.Output
 		}
+		mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams)
 		out["item"] = item
 	}

+	if bifrostEvent.Error != nil {
+		rtErr := map[string]interface{}{}
+		if bifrostEvent.Error.Type != "" {
+			rtErr["type"] = bifrostEvent.Error.Type
+		}
+		if bifrostEvent.Error.Code != "" {
+			rtErr["code"] = bifrostEvent.Error.Code
+		}
+		if bifrostEvent.Error.Message != "" {
+			rtErr["message"] = bifrostEvent.Error.Message
+		}
+		if bifrostEvent.Error.Param != "" {
+			rtErr["param"] = bifrostEvent.Error.Param
+		}
+		mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams)
+		out["error"] = rtErr
+	}
+
 	if bifrostEvent.Delta != nil {
 		if bifrostEvent.Delta.Text != "" {
 			out["delta"] = bifrostEvent.Delta.Text
@@ -259,16 +680,16 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi
 		if bifrostEvent.Delta.Transcript != "" {
 			out["transcript"] = bifrostEvent.Delta.Transcript
 		}
-		if bifrostEvent.Delta.ItemID != "" {
+		if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") {
 			out["item_id"] = bifrostEvent.Delta.ItemID
 		}
-		if bifrostEvent.Delta.OutputIdx != nil {
+		if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") {
 			out["output_index"] = *bifrostEvent.Delta.OutputIdx
 		}
-		if bifrostEvent.Delta.ContentIdx != nil {
+		if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") {
 			out["content_index"] = *bifrostEvent.Delta.ContentIdx
 		}
-		if bifrostEvent.Delta.ResponseID != "" {
+		if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") {
 			out["response_id"] = bifrostEvent.Delta.ResponseID
 		}
 	}
@@ -276,11 +697,269 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi
 	return providerUtils.MarshalSorted(out)
 }

+func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) {
+	filtered := params
+	if eventType == schemas.RTEventSessionUpdate && len(params) > 0 {
+		filtered = make(map[string]json.RawMessage, len(params))
+		for key, value := range params {
+			switch key {
+			case "id", "object", "expires_at", "client_secret":
+				continue
+			default:
+				filtered[key] = value
+			}
+		}
+	}
+	mergeRealtimeExtraParams(out, filtered)
+}
+
+func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage {
+	if len(terminalEventRaw) == 0 {
+		return nil
+	}
+
+	var parsed openAIRealtimeResponseDoneEnvelope
+	if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil {
+		return nil
+	}
+
+	usage := &schemas.BifrostLLMUsage{
+		PromptTokens:     parsed.Response.Usage.InputTokens,
+		CompletionTokens: parsed.Response.Usage.OutputTokens,
+		TotalTokens:      parsed.Response.Usage.TotalTokens,
+	}
+
+	if parsed.Response.Usage.InputTokenDetails != nil {
+		usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
+			TextTokens:       parsed.Response.Usage.InputTokenDetails.TextTokens,
+			AudioTokens:      parsed.Response.Usage.InputTokenDetails.AudioTokens,
+			ImageTokens:      parsed.Response.Usage.InputTokenDetails.ImageTokens,
+			CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens,
+		}
+	}
+
+	if parsed.Response.Usage.OutputTokenDetails != nil {
+		usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
+			TextTokens:               parsed.Response.Usage.OutputTokenDetails.TextTokens,
+			AudioTokens:              parsed.Response.Usage.OutputTokenDetails.AudioTokens,
+			ReasoningTokens:          parsed.Response.Usage.OutputTokenDetails.ReasoningTokens,
+			ImageTokens:              parsed.Response.Usage.OutputTokenDetails.ImageTokens,
+			CitationTokens:           parsed.Response.Usage.OutputTokenDetails.CitationTokens,
+			NumSearchQueries:         parsed.Response.Usage.OutputTokenDetails.NumSearchQueries,
+			AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens,
+			RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens,
+		}
+	}
+
+	return usage
+}
+
+func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage {
+	if len(terminalEventRaw) == 0 {
+		return nil
+	}
+
+	var parsed openAIRealtimeResponseDoneEnvelope
+	if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil {
+		return nil
+	}
+
+	content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output)
+	toolCalls := extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output)
+	if content == "" && len(toolCalls) == 0 {
+		return nil
+	}
+
+	message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
+	if content != "" {
+		message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)}
+	}
+	if len(toolCalls) > 0 {
+		message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls}
+	}
+
+	return message
+}
+
+type openAIRealtimeResponseDoneEnvelope struct {
+	Response struct {
+		Output []openAIRealtimeResponseDoneOutput `json:"output"`
+		Usage  *openAIRealtimeResponseDoneUsage   `json:"usage"`
+	} `json:"response"`
+}
+
+type openAIRealtimeResponseDoneOutput struct {
+	ID        string                            `json:"id"`
+	Type      string                            `json:"type"`
+	Name      string                            `json:"name"`
+	CallID    string                            `json:"call_id"`
+	Arguments string                            `json:"arguments"`
+	Content   []openAIRealtimeResponseDoneBlock `json:"content"`
+}
+
+type openAIRealtimeResponseDoneBlock struct {
+	Text       string `json:"text"`
+	Transcript string `json:"transcript"`
+	Refusal    string `json:"refusal"`
+}
+
+type openAIRealtimeResponseDoneUsage struct {
+	TotalTokens        int                                         `json:"total_tokens"`
+	InputTokens        int                                         `json:"input_tokens"`
+	OutputTokens       int                                         `json:"output_tokens"`
+	InputTokenDetails  *openAIRealtimeResponseDoneInputTokenUsage  `json:"input_token_details"`
+	OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
+}
+
+type openAIRealtimeResponseDoneInputTokenUsage struct {
+	TextTokens   int `json:"text_tokens"`
+	AudioTokens  int `json:"audio_tokens"`
+	ImageTokens  int `json:"image_tokens"`
+	CachedTokens int `json:"cached_tokens"`
+}
+
+type openAIRealtimeResponseDoneOutputTokenUsage struct {
+	TextTokens               int  `json:"text_tokens"`
+	AudioTokens              int  `json:"audio_tokens"`
+	ReasoningTokens          int  `json:"reasoning_tokens"`
+	ImageTokens              *int `json:"image_tokens"`
+	CitationTokens           *int `json:"citation_tokens"`
+	NumSearchQueries         *int `json:"num_search_queries"`
+	AcceptedPredictionTokens int  `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int  `json:"rejected_prediction_tokens"`
+}
+
+func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string {
+	var sb strings.Builder
+	for _, output := range outputs {
+		if output.Type != "message" {
+			continue
+		}
+		for _, block := range output.Content {
+			switch {
+			case strings.TrimSpace(block.Text) != "":
+				sb.WriteString(block.Text)
+			case strings.TrimSpace(block.Transcript) != "":
+				sb.WriteString(block.Transcript)
+			case strings.TrimSpace(block.Refusal) != "":
+				sb.WriteString(block.Refusal)
+			}
+		}
+	}
+	return strings.TrimSpace(sb.String())
+}
+
+func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
+	toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
+	for _, output := range outputs {
+		if output.Type != "function_call" {
+			continue
+		}
+
+		name := strings.TrimSpace(output.Name)
+		if name == "" {
+			continue
+		}
+
+		toolType := "function"
+		id := strings.TrimSpace(output.CallID)
+		if id == "" {
+			id = strings.TrimSpace(output.ID)
+		}
+
+		toolCall := schemas.ChatAssistantMessageToolCall{
+			Index: uint16(len(toolCalls)),
+			Type:  &toolType,
+			Function: schemas.ChatAssistantMessageToolCallFunction{
+				Name:      schemas.Ptr(name),
+				Arguments: output.Arguments,
+			},
+		}
+		if id != "" {
+			toolCall.ID = schemas.Ptr(id)
+		}
+
+		toolCalls = append(toolCalls, toolCall)
+	}
+	return toolCalls
+}
+
+func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) {
+	if event == nil || key == "" || value == nil {
+		return
+	}
+
+	switch v := value.(type) {
+	case string:
+		if v == "" {
+			return
+		}
+	case *int:
+		if v == nil {
+			return
+		}
+	case json.RawMessage:
+		if len(v) == 0 || string(v) == "null" {
+			return
+		}
+	}
+
+	raw, err := json.Marshal(value)
+	if err != nil {
+		return
+	}
+	if event.ExtraParams == nil {
event.ExtraParams = make(map[string]json.RawMessage) + } + event.ExtraParams[key] = raw +} + +func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) { + for key, raw := range params { + if len(raw) == 0 { + continue + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + continue + } + out[key] = value + } +} + +func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool { + if params == nil { + return false + } + raw, ok := params[key] + return ok && len(raw) > 0 +} + +func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage { + if len(raw) == 0 { + return nil + } + root := map[string]json.RawMessage{} + if err := json.Unmarshal(raw, &root); err != nil { + return nil + } + for _, key := range knownKeys { + delete(root, key) + } + if len(root) == 0 { + return nil + } + return root +} + func isRealtimeDeltaEvent(eventType string) bool { switch eventType { case "response.text.delta", + "response.output_text.delta", "response.audio.delta", + "response.output_audio.delta", "response.audio_transcript.delta", + "response.output_audio_transcript.delta", "conversation.item.input_audio_transcription.delta": return true } diff --git a/core/providers/openai/realtime_test.go b/core/providers/openai/realtime_test.go new file mode 100644 index 0000000000..6b7f76f98f --- /dev/null +++ b/core/providers/openai/realtime_test.go @@ -0,0 +1,561 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestNormalizeRealtimeClientSecretRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != 
"gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["model"]; ok { + t.Fatal("top-level model should be removed after normalization") + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeSessionsRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := 
normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointSessions, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["session"]; ok { + t.Fatal("legacy sessions endpoint should not forward nested session object") + } + if payload["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview") + } + if payload["voice"] != "alloy" { + t.Fatalf("voice = %v, want %q", payload["voice"], "alloy") + } +} + +func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + contentIndex, err := json.Marshal(0) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + audioEndMS, err := json.Marshal(640) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.truncate"), + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "content_index": contentIndex, + "audio_end_ms": audioEndMS, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "conversation.item.truncate" { + t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate") + } + if payload["item_id"] != "item_123" { + t.Fatalf("item_id = %v, want 
%q", payload["item_id"], "item_123") + } + if payload["content_index"] != float64(0) { + t.Fatalf("content_index = %v, want 0", payload["content_index"]) + } + if payload["audio_end_ms"] != float64(640) { + t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"]) + } +} + +func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } + var audioEndMS int + if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil { + t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err) + } + if audioEndMS != 640 { + t.Fatalf("audio_end_ms = %d, want 640", audioEndMS) + } +} + +func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var transcript string + if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil { + t.Fatalf("json.Unmarshal(transcript) 
error = %v", err) + } + if transcript != "Who are you?" { + t.Fatalf("transcript = %q, want %q", transcript, "Who are you?") + } +} + +func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.output_text.delta", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "response_id":"resp_123", + "delta":"hello" + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + if event.Delta == nil || event.Delta.Text != "hello" { + t.Fatalf("Delta = %+v, want text=hello", event.Delta) + } +} + +func TestShouldStartRealtimeTurn(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + tests := []struct { + name string + event *schemas.BifrostRealtimeEvent + want bool + }{ + { + name: "response create starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate}, + want: true, + }, + { + name: "audio buffer committed starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted}, + want: true, + }, + { + name: "response done does not start turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone}, + want: false, + }, + { + name: "nil event does not start turn", + event: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want { + t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + outputIndex := 0 + contentIndex := 0 + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("response.output_text.delta"), + Delta: 
&schemas.RealtimeDelta{ + Text: "hello", + ItemID: "item_123", + OutputIdx: &outputIndex, + ContentIdx: &contentIndex, + ResponseID: "resp_123", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "response.output_text.delta" { + t.Fatalf("type = %v, want response.output_text.delta", payload["type"]) + } + if payload["delta"] != "hello" { + t.Fatalf("delta = %v, want hello", payload["delta"]) + } +} + +func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionCreated, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + session, ok := payload["session"].(map[string]any) + if !ok { + t.Fatalf("session = %T, want object", payload["session"]) + } + if session["id"] != "sess_123" { + t.Fatalf("session.id = %v, want sess_123", session["id"]) + } +} + +func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`) + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.retrieved"), + Item: &schemas.RealtimeItem{ + ID: "item_123", + Type: "message", + Role: "user", + Status: "completed", + Content: content, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := 
json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + item, ok := payload["item"].(map[string]any) + if !ok { + t.Fatalf("item = %T, want object", payload["item"]) + } + if item["status"] != "completed" { + t.Fatalf("item.status = %v, want completed", item["status"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.done", + "event_id":"evt_123", + "response":{ + "id":"resp_123", + "output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}] + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var response map[string]any + if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil { + t.Fatalf("json.Unmarshal(response) error = %v", err) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseDone, + ExtraParams: map[string]json.RawMessage{ + "response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + response, ok := payload["response"].(map[string]any) + if !ok { + t.Fatalf("response = %T, want object", payload["response"]) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func 
TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.content_part.added", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "part":{ + "type":"text", + "text":"hello" + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var part map[string]any + if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil { + t.Fatalf("json.Unmarshal(part) error = %v", err) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseContentPartAdded, + ExtraParams: map[string]json.RawMessage{ + "part": json.RawMessage(`{"type":"text","text":"hello"}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + part, ok := payload["part"].(map[string]any) + if !ok { + t.Fatalf("part = %T, want object", payload["part"]) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + event, err := schemas.ParseRealtimeEvent([]byte(`{ + "type":"session.update", + "session":{ + "type":"realtime", + "model":"gpt-4o-realtime-preview", + "output_modalities":["text"] + } + }`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + if event.Session == nil { + t.Fatal("expected session to be parsed") + } + var outputModalities []string + if err := 
json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil { + t.Fatalf("json.Unmarshal(output_modalities) error = %v", err) + } + if len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("output_modalities = %v, want [text]", outputModalities) + } +} + +func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + Model: "gpt-4o-realtime-preview", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "output_modalities": json.RawMessage(`["text"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Type string `json:"type"` + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload.Type != "session.update" { + t.Fatalf("type = %q, want %q", payload.Type, "session.update") + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + outputModalities, ok := payload.Session["output_modalities"].([]any) + if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"]) + } +} + +func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "object": 
json.RawMessage(`"realtime.session"`), + "expires_at": json.RawMessage(`1774614381`), + "client_secret": json.RawMessage(`{"value":"secret"}`), + "modalities": json.RawMessage(`["text","audio"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + for _, key := range []string{"id", "object", "expires_at", "client_secret"} { + if _, ok := payload.Session[key]; ok { + t.Fatalf("session.%s unexpectedly present in session.update payload", key) + } + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + if payload.Session["model"] != "gpt-realtime" { + t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"]) + } +} diff --git a/core/providers/openai/responses.go b/core/providers/openai/responses.go index 23f59c5155..e8efee4689 100644 --- a/core/providers/openai/responses.go +++ b/core/providers/openai/responses.go @@ -201,9 +201,12 @@ func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Open if req.ResponsesParameters.Reasoning.Effort != nil { // Native field is provided, use it (and clear max_tokens) effort := *req.ResponsesParameters.Reasoning.Effort - // Convert "minimal" to "low" for non-OpenAI providers - if effort == "minimal" { + // Convert "minimal" to "low"; cap "xhigh"/"max" to "high" — OpenAI tops out at high. 
+ switch effort { + case "minimal": req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("low") + case "xhigh", "max": + req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("high") } // Clear max_tokens since OpenAI doesn't use it req.ResponsesParameters.Reasoning.MaxTokens = nil @@ -220,6 +223,11 @@ func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Open req.ResponsesParameters.Reasoning.MaxTokens = nil } + // summary:"none" is Anthropic-specific (maps to display:"omitted"); strip it for OpenAI. + if req.ResponsesParameters.Reasoning.Summary != nil && *req.ResponsesParameters.Reasoning.Summary == "none" { + req.ResponsesParameters.Reasoning.Summary = nil + } + // Handle xAI-specific parameter filtering // Only grok-3-mini supports reasoning_effort if bifrostReq.Provider == schemas.XAI && diff --git a/core/providers/openai/responses_marshal_test.go b/core/providers/openai/responses_marshal_test.go index d9f8616a18..092e1eb813 100644 --- a/core/providers/openai/responses_marshal_test.go +++ b/core/providers/openai/responses_marshal_test.go @@ -523,3 +523,350 @@ func TestOpenAIResponsesRequest_MarshalJSON_RoundTrip(t *testing.T) { } }) } + +// Regression test for multi-turn Anthropic tool_result with array-form content. +// The OpenAI Responses API defines function_call_output.output as a string (see +// https://platform.openai.com/docs/api-reference/responses/create). When an +// Anthropic client sends a tool_result whose content is an array of text blocks, +// Bifrost's Anthropic→Responses translator populates +// ResponsesToolMessageOutputStruct.ResponsesFunctionToolCallOutputBlocks. +// Historically, that array was marshaled verbatim onto the wire, which some +// strict OpenAI-compat upstreams (e.g. 
Ollama Cloud) reject with an error like +// +// json: cannot unmarshal array into Go struct field ResponsesFunctionCallOutput.output of type string +// +// The outgoing OpenAI Responses request must emit `output` as a string for +// text-only tool outputs. +func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputFlattensTextBlocksToString(t *testing.T) { + outputText := "line1" + callID := "toolu_abc123" + functionName := "read_file" + + input := &OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("Read /tmp/test.txt and tell me what it contains."), + }, + }, + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(callID), + Name: schemas.Ptr(functionName), + Arguments: schemas.Ptr(`{"path":"/tmp/test.txt"}`), + }, + }, + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(callID), + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: schemas.Ptr(outputText), + }, + }, + }, + }, + }, + }, + } + + jsonBytes, err := input.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal OpenAIResponsesRequestInput: %v", err) + } + + var messages []map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &messages); err != nil { + t.Fatalf("Failed to unmarshal marshaled input as array: %v\nraw=%s", err, string(jsonBytes)) + } + + var fcoMsg map[string]interface{} + for _, m := range messages { + if t, ok := m["type"].(string); ok && t == 
string(schemas.ResponsesMessageTypeFunctionCallOutput) { + fcoMsg = m + break + } + } + if fcoMsg == nil { + t.Fatalf("did not find function_call_output message in marshaled JSON: %s", string(jsonBytes)) + } + + outputVal, ok := fcoMsg["output"] + if !ok { + t.Fatalf("function_call_output message has no `output` field: %s", string(jsonBytes)) + } + + outputStr, isString := outputVal.(string) + if !isString { + t.Fatalf("function_call_output.output must be a string (OpenAI Responses API spec); got %T: %v\nraw=%s", outputVal, outputVal, string(jsonBytes)) + } + if outputStr != outputText { + t.Fatalf("function_call_output.output mismatch: want %q, got %q", outputText, outputStr) + } +} + +// Flattening must concatenate multiple text blocks with newline separators so +// every character from the upstream tool response reaches the model. +func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputConcatenatesMultipleTextBlocks(t *testing.T) { + callID := "toolu_multi" + input := &OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(callID), + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{ + {Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line1")}, + {Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line2")}, + }, + }, + }, + }, + }, + } + + jsonBytes, err := input.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var messages []map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &messages); err != nil { + t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes)) + } + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + 
got, ok := messages[0]["output"].(string) + if !ok { + t.Fatalf("output must be string, got %T", messages[0]["output"]) + } + if want := "line1\nline2"; got != want { + t.Fatalf("flattened output mismatch: want %q, got %q", want, got) + } +} + +// When the tool result contains a non-text block (e.g. an image), flattening is +// unsafe — preserve the array form and let the upstream handle it. This keeps +// the fix scoped to the common text-only case without dropping rich content. +func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputPreservesNonTextBlocks(t *testing.T) { + callID := "toolu_with_image" + imageURL := "https://example.com/screenshot.png" + input := &OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(callID), + Output: &schemas.ResponsesToolMessageOutputStruct{ + ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{ + {Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("here is the screenshot:")}, + { + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: &imageURL, + }, + }, + }, + }, + }, + }, + }, + } + jsonBytes, err := input.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + var messages []map[string]interface{} + if err := sonic.Unmarshal(jsonBytes, &messages); err != nil { + t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes)) + } + if _, isString := messages[0]["output"].(string); isString { + t.Fatalf("non-text blocks must not be flattened to string; raw=%s", string(jsonBytes)) + } +} + +// TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags ensures the +// Responses serializer drops the four 
Anthropic-native tool flags +// (defer_loading, allowed_callers, input_examples, eager_input_streaming) +// along with CacheControl before forwarding to OpenAI — mirroring the Chat +// path's behavior so Anthropic-flavored tools cannot 400 OpenAI via Responses. +func TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags(t *testing.T) { + req := &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("hello"), + }, + }, + }, + }, + ResponsesParameters: schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + { + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("lookup"), + Description: schemas.Ptr("lookup something"), + CacheControl: &schemas.CacheControl{Type: "ephemeral"}, + DeferLoading: schemas.Ptr(true), + AllowedCallers: []string{"direct", "agent"}, + EagerInputStreaming: schemas.Ptr(false), + InputExamples: []schemas.ChatToolInputExample{ + {Input: json.RawMessage(`{"q":"hi"}`)}, + }, + ResponsesToolFunction: &schemas.ResponsesToolFunction{}, + }, + }, + }, + } + + jsonBytes, err := req.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + raw := string(jsonBytes) + + // None of the five Anthropic-only tool keys must survive on the wire. + for _, key := range []string{`"cache_control"`, `"defer_loading"`, `"allowed_callers"`, `"input_examples"`, `"eager_input_streaming"`} { + if strings.Contains(raw, key) { + t.Errorf("OpenAI Responses serializer must strip %s; raw=%s", key, raw) + } + } + // Function tool identity should be preserved. 
+ if !strings.Contains(raw, `"name":"lookup"`) { + t.Errorf("tool identity lost after strip; raw=%s", raw) + } +} + +// TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes verifies +// that Anthropic-only tool types (web_fetch, memory) are dropped entirely when +// serializing for OpenAI Responses. Per OpenAI's OpenAPI spec the Responses +// Tool discriminator union does not include web_fetch or memory, so forwarding +// them would trigger a 400 schema-validation error. Mirrors the Chat path's +// isAnthropicServerToolShape drop behavior. +func TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes(t *testing.T) { + req := &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: schemas.Ptr("hello"), + }, + }, + }, + }, + ResponsesParameters: schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + // Kept: function (OpenAI-native). + { + Type: schemas.ResponsesToolTypeFunction, + Name: schemas.Ptr("keeper_func"), + ResponsesToolFunction: &schemas.ResponsesToolFunction{}, + }, + // Dropped: web_fetch (Anthropic-only). + { + Type: schemas.ResponsesToolTypeWebFetch, + Name: schemas.Ptr("anthropic_webfetch"), + ResponsesToolWebFetch: &schemas.ResponsesToolWebFetch{}, + }, + // Kept: web_search (both support). + { + Type: schemas.ResponsesToolTypeWebSearch, + ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}, + }, + // Dropped: memory (Anthropic-only). + { + Type: schemas.ResponsesToolTypeMemory, + Name: schemas.Ptr("anthropic_memory"), + }, + // Kept: tool_search (both support per OpenAI OpenAPI spec). 
+ { + Type: schemas.ResponsesToolTypeToolSearch, + }, + }, + }, + } + + jsonBytes, err := req.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + raw := string(jsonBytes) + + // Dropped types must not appear on the wire. + for _, dropped := range []string{`"web_fetch"`, `"memory"`, `"anthropic_webfetch"`, `"anthropic_memory"`} { + if strings.Contains(raw, dropped) { + t.Errorf("Anthropic-only tool must be dropped; found %s in raw=%s", dropped, raw) + } + } + // Kept types must still appear. + for _, kept := range []string{`"function"`, `"web_search"`, `"tool_search"`, `"keeper_func"`} { + if !strings.Contains(raw, kept) { + t.Errorf("supported tool %s should be preserved; raw=%s", kept, raw) + } + } + + // Confirm the tools array is present and has exactly 3 entries (2 dropped of 5). + var decoded struct { + Tools []map[string]interface{} `json:"tools"` + } + if err := json.Unmarshal(jsonBytes, &decoded); err != nil { + t.Fatalf("decode failed: %v", err) + } + if len(decoded.Tools) != 3 { + t.Errorf("expected 3 tools after drop (function, web_search, tool_search), got %d; tools=%+v", len(decoded.Tools), decoded.Tools) + } +} + +// TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported verifies the +// no-reshape fast path: if every tool is OpenAI-compatible with no +// Anthropic-only flags, the tools slice passes through unchanged (no copy, +// no drop). 
+func TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported(t *testing.T) { + req := &OpenAIResponsesRequest{ + Model: "gpt-4o", + Input: OpenAIResponsesRequestInput{ + OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{ + { + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr("hi")}, + }, + }, + }, + ResponsesParameters: schemas.ResponsesParameters{ + Tools: []schemas.ResponsesTool{ + {Type: schemas.ResponsesToolTypeFunction, Name: schemas.Ptr("f"), ResponsesToolFunction: &schemas.ResponsesToolFunction{}}, + {Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}}, + {Type: schemas.ResponsesToolTypeCodeInterpreter, ResponsesToolCodeInterpreter: &schemas.ResponsesToolCodeInterpreter{}}, + }, + }, + } + + jsonBytes, err := req.MarshalJSON() + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + var decoded struct { + Tools []map[string]interface{} `json:"tools"` + } + if err := json.Unmarshal(jsonBytes, &decoded); err != nil { + t.Fatalf("decode failed: %v", err) + } + if len(decoded.Tools) != 3 { + t.Errorf("expected 3 tools preserved, got %d", len(decoded.Tools)) + } +} diff --git a/core/providers/openai/text_test.go b/core/providers/openai/text_test.go index b2dd53ee35..71c2f195a0 100644 --- a/core/providers/openai/text_test.go +++ b/core/providers/openai/text_test.go @@ -51,7 +51,6 @@ func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/openai/transcription.go b/core/providers/openai/transcription.go index 8ab2305b05..8c2bf112a1 100644 --- a/core/providers/openai/transcription.go +++ b/core/providers/openai/transcription.go @@ -54,63 
+54,63 @@ func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR } fileWriter, err := writer.CreateFormFile("file", filename) if err != nil { - return utils.NewBifrostOperationError("failed to create form file", err, providerName) + return utils.NewBifrostOperationError("failed to create form file", err) } if _, err := fileWriter.Write(openaiReq.File); err != nil { - return utils.NewBifrostOperationError("failed to write file data", err, providerName) + return utils.NewBifrostOperationError("failed to write file data", err) } // Add model field if err := writer.WriteField("model", openaiReq.Model); err != nil { - return utils.NewBifrostOperationError("failed to write model field", err, providerName) + return utils.NewBifrostOperationError("failed to write model field", err) } // Add optional fields if openaiReq.Language != nil { if err := writer.WriteField("language", *openaiReq.Language); err != nil { - return utils.NewBifrostOperationError("failed to write language field", err, providerName) + return utils.NewBifrostOperationError("failed to write language field", err) } } if openaiReq.Prompt != nil { if err := writer.WriteField("prompt", *openaiReq.Prompt); err != nil { - return utils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return utils.NewBifrostOperationError("failed to write prompt field", err) } } if openaiReq.ResponseFormat != nil { if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil { - return utils.NewBifrostOperationError("failed to write response_format field", err, providerName) + return utils.NewBifrostOperationError("failed to write response_format field", err) } } if openaiReq.Temperature != nil { if err := writer.WriteField("temperature", fmt.Sprintf("%g", *openaiReq.Temperature)); err != nil { - return utils.NewBifrostOperationError("failed to write temperature field", err, providerName) + return utils.NewBifrostOperationError("failed to write 
temperature field", err) } } for _, granularity := range openaiReq.TimestampGranularities { if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil { - return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err, providerName) + return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err) } } for _, include := range openaiReq.Include { if err := writer.WriteField("include[]", include); err != nil { - return utils.NewBifrostOperationError("failed to write include field", err, providerName) + return utils.NewBifrostOperationError("failed to write include field", err) } } if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { - return utils.NewBifrostOperationError("failed to write stream field", err, providerName) + return utils.NewBifrostOperationError("failed to write stream field", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return utils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return utils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openai/types.go b/core/providers/openai/types.go index 89de4e1e66..e2eab5245a 100644 --- a/core/providers/openai/types.go +++ b/core/providers/openai/types.go @@ -4,10 +4,11 @@ import ( "encoding/json" "errors" "fmt" + "strings" "github.com/bytedance/sonic" - "github.com/maximhq/bifrost/core/schemas" providerUtils "github.com/maximhq/bifrost/core/providers/utils" + "github.com/maximhq/bifrost/core/schemas" ) const MinMaxCompletionTokens = 16 @@ -82,7 +83,7 @@ type OpenAIChatRequest struct { // PromptCacheIsolationKey is the Fireworks chat-completions field for cache isolation. 
PromptCacheIsolationKey *string `json:"prompt_cache_isolation_key,omitempty"` - //NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. + // NOTE: MaxCompletionTokens is a new replacement for max_tokens but some providers still use max_tokens. // This Field is populated only for such providers and is NOT to be used externally. MaxTokens *int `json:"max_tokens,omitempty"` @@ -184,27 +185,42 @@ func (req *OpenAIChatRequest) MarshalJSON() ([]byte, error) { processedMessages = req.Messages } - // Process tools if needed + // Process tools if needed. + // On outbound to OpenAI we need to: + // (a) Strip CacheControl (Anthropic-only, existing behavior). + // (b) Drop Anthropic server tools entirely (Function == nil && Custom == nil); + // OpenAI won't accept web_search_20260209 etc. + // (c) Strip Anthropic-native per-tool flags (DeferLoading, AllowedCallers, + // InputExamples, EagerInputStreaming) when they're set on function tools. var processedTools []schemas.ChatTool if len(req.Tools) > 0 { - needsToolCopy := false + needsToolChange := false for _, tool := range req.Tools { - if tool.CacheControl != nil { - needsToolCopy = true + if tool.CacheControl != nil || isAnthropicServerToolShape(tool) || hasAnthropicOnlyToolFlags(tool) { + needsToolChange = true break } } - if needsToolCopy { - processedTools = make([]schemas.ChatTool, len(req.Tools)) - for i, tool := range req.Tools { - if tool.CacheControl != nil { - toolCopy := tool - toolCopy.CacheControl = nil - processedTools[i] = toolCopy - } else { - processedTools[i] = tool + if needsToolChange { + processedTools = make([]schemas.ChatTool, 0, len(req.Tools)) + for _, tool := range req.Tools { + // Drop Anthropic server tools (no function/custom payload). + // OpenAI would reject the request if we forwarded them. 
+ if isAnthropicServerToolShape(tool) { + continue } + if tool.CacheControl == nil && !hasAnthropicOnlyToolFlags(tool) { + processedTools = append(processedTools, tool) + continue + } + toolCopy := tool + toolCopy.CacheControl = nil + toolCopy.DeferLoading = nil + toolCopy.AllowedCallers = nil + toolCopy.InputExamples = nil + toolCopy.EagerInputStreaming = nil + processedTools = append(processedTools, toolCopy) } } else { processedTools = req.Tools @@ -427,8 +443,23 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { } } - // Strip CacheControl and FileType from tool message output blocks if needed - if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + // Collapse text-only tool output blocks into a single string. + // OpenAI's Responses API defines function_call_output.output as + // a string; Anthropic's multi-turn tool_result content arrives + // as an array of content blocks and has to be flattened here. + // Strict upstream implementations (e.g. Ollama Cloud) return a + // 400 otherwise. 
+ if msg.ResponsesToolMessage.Output != nil && + msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil && + isFunctionCallOutputBlocksFlattenable(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) { + flattened := flattenFunctionCallOutputBlocks(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) + outputCopy := *msg.ResponsesToolMessage.Output + outputCopy.ResponsesToolCallOutputStr = &flattened + outputCopy.ResponsesFunctionToolCallOutputBlocks = nil + toolMsgCopy.Output = &outputCopy + toolMsgModified = true + } else if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + // Strip CacheControl and FileType from tool message output blocks if needed hasToolModification := false for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { if block.CacheControl != nil || block.Citations != nil || (block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileType != nil) { @@ -473,6 +504,52 @@ func (r *OpenAIResponsesRequestInput) MarshalJSON() ([]byte, error) { } // Helper function to check if a chat message has any CacheControl fields or FileType in file blocks +// isAnthropicServerToolShape reports whether the tool carries the Anthropic +// server-tool shape (Function and Custom both nil). On outbound to OpenAI, +// these must be dropped — OpenAI doesn't accept tool types like +// web_search_20260209, computer_20251124, mcp_toolset, etc. +func isAnthropicServerToolShape(t schemas.ChatTool) bool { + return t.Function == nil && t.Custom == nil +} + +// hasAnthropicOnlyToolFlags reports whether the tool carries any of the +// Anthropic-native flags that OpenAI would reject (DeferLoading, +// AllowedCallers, InputExamples, EagerInputStreaming). Strip these when +// forwarding to OpenAI. 
+func hasAnthropicOnlyToolFlags(t schemas.ChatTool) bool { + return t.DeferLoading != nil || + len(t.AllowedCallers) > 0 || + len(t.InputExamples) > 0 || + t.EagerInputStreaming != nil +} + +// hasAnthropicOnlyResponsesToolFlags is the ResponsesTool-typed parallel of +// hasAnthropicOnlyToolFlags. The four flags were promoted onto ResponsesTool +// in core/schemas/responses.go for the Anthropic-via-Responses path; the +// OpenAI Responses serializer must strip them so they don't leak to OpenAI +// and trigger a 400 on unknown fields. +func hasAnthropicOnlyResponsesToolFlags(t schemas.ResponsesTool) bool { + return t.DeferLoading != nil || + len(t.AllowedCallers) > 0 || + len(t.InputExamples) > 0 || + t.EagerInputStreaming != nil +} + +// isAnthropicOnlyResponsesToolType reports whether the tool type exists only +// in Anthropic's taxonomy and is not part of OpenAI's Responses API Tool union +// (per OpenAI's OpenAPI spec component.schemas.Tool, which enumerates function, +// file_search, computer[_use_preview], web_search[_preview], mcp, +// code_interpreter, image_generation, local_shell, custom, tool_search, and +// related shell/namespace/apply_patch variants). Forwarding web_fetch or +// memory to OpenAI guarantees a 400 on schema discriminator validation, so +// these get dropped in the Responses→OpenAI serializer — mirroring the Chat +// path's isAnthropicServerToolShape drop behavior for schema parity across +// both endpoints. 
+func isAnthropicOnlyResponsesToolType(t schemas.ResponsesTool) bool { + return t.Type == schemas.ResponsesToolTypeWebFetch || + t.Type == schemas.ResponsesToolTypeMemory +} + func hasFieldsToStripInChatMessage(msg OpenAIMessage) bool { if msg.Content != nil && msg.Content.ContentBlocks != nil { for _, block := range msg.Content.ContentBlocks { @@ -527,6 +604,12 @@ func hasFieldsToStripInResponsesMessage(msg schemas.ResponsesMessage) bool { } // Check output blocks if msg.ResponsesToolMessage.Output != nil && msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks != nil { + // Text-only block arrays must be flattened to a string — OpenAI's + // Responses API defines function_call_output.output as a string + // and strict upstreams reject the array form. + if isFunctionCallOutputBlocksFlattenable(msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks) { + return true + } for _, block := range msg.ResponsesToolMessage.Output.ResponsesFunctionToolCallOutputBlocks { if block.CacheControl != nil { return true @@ -540,6 +623,41 @@ func hasFieldsToStripInResponsesMessage(msg schemas.ResponsesMessage) bool { return false } +// isFunctionCallOutputBlocksFlattenable reports whether a function_call_output +// content block slice contains only text blocks and can therefore be collapsed +// into a single string for the OpenAI Responses API wire format. +func isFunctionCallOutputBlocksFlattenable(blocks []schemas.ResponsesMessageContentBlock) bool { + if len(blocks) == 0 { + return false + } + for _, block := range blocks { + if block.Type != schemas.ResponsesInputMessageContentBlockTypeText && + block.Type != schemas.ResponsesOutputMessageContentTypeText { + return false + } + if block.Text == nil { + return false + } + } + return true +} + +// flattenFunctionCallOutputBlocks concatenates the text of every block in the +// slice. Callers must first verify flattenability via +// isFunctionCallOutputBlocksFlattenable. 
+func flattenFunctionCallOutputBlocks(blocks []schemas.ResponsesMessageContentBlock) string { + var b strings.Builder + for i, block := range blocks { + if i > 0 { + b.WriteByte('\n') + } + if block.Text != nil { + b.WriteString(*block.Text) + } + } + return b.String() +} + // filterSupportedAnnotations filters out unsupported (non-OpenAI native) citation types // OpenAI supports: file_citation, url_citation, container_file_citation, file_path func filterSupportedAnnotations(annotations []schemas.ResponsesOutputMessageContentTextAnnotation) []schemas.ResponsesOutputMessageContentTextAnnotation { @@ -604,27 +722,45 @@ func (resp *OpenAIResponsesRequest) MarshalJSON() ([]byte, error) { return nil, err } - // Process tools if needed + // Process tools if needed. + // Mirrors the Chat path (see ChatRequest.MarshalJSON) so the same + // Anthropic-flavored tool payload doesn't leak via the Responses serializer: + // (a) Drop Anthropic-only tool TYPES entirely (web_fetch, memory) since + // OpenAI's Responses Tool union doesn't include them — forwarding + // would 400 on the discriminator. + // (b) Strip CacheControl (Anthropic-only schema field). + // (c) Strip the four Anthropic-native per-tool flags (DeferLoading, + // AllowedCallers, InputExamples, EagerInputStreaming). 
var processedTools []schemas.ResponsesTool if len(resp.Tools) > 0 { - needsToolCopy := false + needsReshape := false for _, tool := range resp.Tools { - if tool.CacheControl != nil { - needsToolCopy = true + if isAnthropicOnlyResponsesToolType(tool) || + tool.CacheControl != nil || + hasAnthropicOnlyResponsesToolFlags(tool) { + needsReshape = true break } } - if needsToolCopy { - processedTools = make([]schemas.ResponsesTool, len(resp.Tools)) - for i, tool := range resp.Tools { - if tool.CacheControl != nil { - toolCopy := tool - toolCopy.CacheControl = nil - processedTools[i] = toolCopy - } else { - processedTools[i] = tool + if needsReshape { + processedTools = make([]schemas.ResponsesTool, 0, len(resp.Tools)) + for _, tool := range resp.Tools { + if isAnthropicOnlyResponsesToolType(tool) { + // Drop — OpenAI Responses has no web_fetch or memory. + continue + } + if tool.CacheControl == nil && !hasAnthropicOnlyResponsesToolFlags(tool) { + processedTools = append(processedTools, tool) + continue } + toolCopy := tool + toolCopy.CacheControl = nil + toolCopy.DeferLoading = nil + toolCopy.AllowedCallers = nil + toolCopy.InputExamples = nil + toolCopy.EagerInputStreaming = nil + processedTools = append(processedTools, toolCopy) } } else { processedTools = resp.Tools diff --git a/core/providers/openai/videos.go b/core/providers/openai/videos.go index aa4052e029..512306b7c7 100644 --- a/core/providers/openai/videos.go +++ b/core/providers/openai/videos.go @@ -132,30 +132,30 @@ func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *sc func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError { // Add prompt field (required) if openaiReq.Prompt == "" { - return providerUtils.NewBifrostOperationError("prompt is required", nil, providerName) + return providerUtils.NewBifrostOperationError("prompt is required", nil) } if err := 
writer.WriteField("prompt", openaiReq.Prompt); err != nil { - return providerUtils.NewBifrostOperationError("failed to write prompt field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write prompt field", err) } // Add optional model field if openaiReq.Model != "" { if err := writer.WriteField("model", openaiReq.Model); err != nil { - return providerUtils.NewBifrostOperationError("failed to write model field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write model field", err) } } // Add optional seconds field if openaiReq.Seconds != nil { if err := writer.WriteField("seconds", *openaiReq.Seconds); err != nil { - return providerUtils.NewBifrostOperationError("failed to write seconds field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write seconds field", err) } } // Add optional size field if openaiReq.Size != "" { if err := writer.WriteField("size", openaiReq.Size); err != nil { - return providerUtils.NewBifrostOperationError("failed to write size field", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write size field", err) } } @@ -196,16 +196,16 @@ func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, opena "Content-Type": {mimeType}, }) if err != nil { - return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err, providerName) + return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err) } if _, err := part.Write(openaiReq.InputReference); err != nil { - return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err, providerName) + return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err) } } // Close the multipart writer if err := writer.Close(); err != nil { - return providerUtils.NewBifrostOperationError("failed to close multipart writer", err, 
providerName) + return providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } return nil diff --git a/core/providers/openrouter/openrouter.go b/core/providers/openrouter/openrouter.go index 799ea4b5ab..63ae8f48e4 100644 --- a/core/providers/openrouter/openrouter.go +++ b/core/providers/openrouter/openrouter.go @@ -4,7 +4,6 @@ package openrouter import ( "fmt" "net/http" - "slices" "strings" "time" @@ -95,12 +94,12 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // Check for auth errors (401, 403) statusCode := resp.StatusCode() if statusCode == fasthttp.StatusUnauthorized || statusCode == fasthttp.StatusForbidden { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } // Any 4xx/5xx error indicates the key might be invalid if statusCode >= 400 { - return openai.ParseOpenAIError(resp, schemas.ListModelsRequest, provider.GetProviderKey(), "") + return openai.ParseOpenAIError(resp) } return nil @@ -109,8 +108,6 @@ func (provider *OpenRouterProvider) validateKey(ctx *schemas.BifrostContext, key // listModelsByKey performs a list models request for a single key. // Returns the response and latency, or an error if the request fails. func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - // Validate the key first using /v1/auth/key (only during provider add/update). // OpenRouter's /v1/models doesn't require auth, so we need this extra check. shouldValidate := false @@ -158,7 +155,7 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, // Continue with empty response; allowed models will be backfilled below. 
modelsFetched = false } else { - bifrostErr := openai.ParseOpenAIError(resp, schemas.ListModelsRequest, providerName, "") + bifrostErr := openai.ParseOpenAIError(resp) return nil, bifrostErr } } @@ -185,45 +182,62 @@ func (provider *OpenRouterProvider) listModelsByKey(ctx *schemas.BifrostContext, } } - // Filter by key.Models - allowedModels := key.Models - blacklistedModels := key.BlacklistedModels + // OpenRouter model IDs in the API response do NOT include the "openrouter/" prefix + // (e.g. the API returns "openai/gpt-4", not "openrouter/openai/gpt-4"). + // Users may supply allowedModels / aliases with or without the prefix, so we + // normalize both by stripping it before feeding into the shared pipeline. providerPrefix := string(schemas.OpenRouter) + "/" + stripPrefix := func(s string) string { + if strings.HasPrefix(strings.ToLower(s), strings.ToLower(providerPrefix)) { + return s[len(providerPrefix):] + } + return s + } + + normalizedAllowed := make(schemas.WhiteList, 0, len(key.Models)) + for _, m := range key.Models { + normalizedAllowed = append(normalizedAllowed, stripPrefix(m)) + } + normalizedBlacklist := make(schemas.BlackList, 0, len(key.BlacklistedModels)) + for _, m := range key.BlacklistedModels { + normalizedBlacklist = append(normalizedBlacklist, stripPrefix(m)) + } + normalizedAliases := make(map[string]string, len(key.Aliases)) + for k, v := range key.Aliases { + normalizedAliases[stripPrefix(k)] = stripPrefix(v) + } + + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: normalizedAllowed, + BlacklistedModels: normalizedBlacklist, + Aliases: normalizedAliases, + Unfiltered: request.Unfiltered, + ProviderKey: schemas.OpenRouter, + MatchFns: providerUtils.DefaultMatchFns(), + } - if !request.Unfiltered && len(allowedModels) > 0 { + if pipeline.ShouldEarlyExit() { + openrouterResponse.Data = make([]schemas.Model, 0) + } else { + included := make(map[string]bool) filteredData := make([]schemas.Model, 0, len(openrouterResponse.Data)) 
- includedModels := make(map[string]bool) for i := range openrouterResponse.Data { + // rawID has no "openrouter/" prefix — e.g. "openai/gpt-4" rawID := openrouterResponse.Data[i].ID - if !(slices.Contains(allowedModels, rawID) || slices.Contains(allowedModels, providerPrefix+rawID)) { - continue - } - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - openrouterResponse.Data[i].ID = providerPrefix + rawID - filteredData = append(filteredData, openrouterResponse.Data[i]) - includedModels[rawID] = true - } - // Backfill allowed models not in the API response - for _, allowedModel := range allowedModels { - rawID := strings.TrimPrefix(allowedModel, providerPrefix) - if slices.Contains(blacklistedModels, rawID) || slices.Contains(blacklistedModels, providerPrefix+rawID) { - continue - } - if !includedModels[rawID] { - filteredData = append(filteredData, schemas.Model{ - ID: providerPrefix + rawID, - Name: schemas.Ptr(rawID), - }) - includedModels[rawID] = true // avoid duplicate backfill + for _, result := range pipeline.FilterModel(rawID) { + entry := openrouterResponse.Data[i] + entry.ID = providerPrefix + result.ResolvedID + if result.AliasValue != "" { + entry.Alias = schemas.Ptr(result.AliasValue) + } else { + entry.Alias = nil + } + filteredData = append(filteredData, entry) + included[strings.ToLower(result.ResolvedID)] = true } } + filteredData = append(filteredData, pipeline.BackfillModels(included)...) 
openrouterResponse.Data = filteredData - } else { - for i := range openrouterResponse.Data { - openrouterResponse.Data[i].ID = providerPrefix + openrouterResponse.Data[i].ID - } } openrouterResponse.ExtraFields.Latency = latency.Milliseconds() diff --git a/core/providers/openrouter/openrouter_test.go b/core/providers/openrouter/openrouter_test.go index d908d4d950..5a6a38ce3f 100644 --- a/core/providers/openrouter/openrouter_test.go +++ b/core/providers/openrouter/openrouter_test.go @@ -31,26 +31,26 @@ func TestOpenRouter(t *testing.T) { EmbeddingModel: "qwen/qwen3-embedding-4b", ReasoningModel: "openai/gpt-oss-120b", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: false, // OpenRouter's responses API is in Beta + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: false, // OpenRouter's responses API is in Beta MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, // OpenRouter's responses API is in Beta - ImageBase64: false, // OpenRouter's responses API is in Beta - MultipleImages: false, // OpenRouter's responses API is in Beta - FileBase64: true, - FileURL: true, - CompleteEnd2End: false, // OpenRouter's responses API is in Beta - Reasoning: true, - ListModels: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - Embedding: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, // OpenRouter's responses API is in Beta + ImageBase64: false, // OpenRouter's responses API is in Beta + MultipleImages: false, // OpenRouter's responses API is in Beta + FileBase64: true, + FileURL: true, + CompleteEnd2End: false, // OpenRouter's responses API is in Beta + Reasoning: true, + ListModels: true, + StructuredOutputs: 
true, // Structured outputs with nullable enum support + Embedding: true, }, } diff --git a/core/providers/parasail/parasail.go b/core/providers/parasail/parasail.go index 0af0e4bba4..e03d891d38 100644 --- a/core/providers/parasail/parasail.go +++ b/core/providers/parasail/parasail.go @@ -145,9 +145,6 @@ func (provider *ParasailProvider) Responses(ctx *schemas.BifrostContext, key sch } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/chat.go b/core/providers/perplexity/chat.go index dafe1c615b..403e0d2dd9 100644 --- a/core/providers/perplexity/chat.go +++ b/core/providers/perplexity/chat.go @@ -38,10 +38,14 @@ func ToPerplexityChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) * // Handle reasoning effort mapping if bifrostReq.Params.Reasoning != nil && bifrostReq.Params.Reasoning.Effort != nil { - if *bifrostReq.Params.Reasoning.Effort == "minimal" { + effort := *bifrostReq.Params.Reasoning.Effort + switch effort { + case "minimal": perplexityReq.ReasoningEffort = schemas.Ptr("low") - } else { - perplexityReq.ReasoningEffort = bifrostReq.Params.Reasoning.Effort + case "xhigh", "max": + perplexityReq.ReasoningEffort = schemas.Ptr("high") + default: + perplexityReq.ReasoningEffort = &effort } } @@ -280,8 +284,6 @@ func (response *PerplexityChatResponse) ToBifrostChatResponse(model string) *sch Object: response.Object, Created: response.Created, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: schemas.Perplexity, }, SearchResults: response.SearchResults, Videos: response.Videos, diff --git a/core/providers/perplexity/perplexity.go b/core/providers/perplexity/perplexity.go index 8ceec1fd4c..f0b21ec21d 100644 --- a/core/providers/perplexity/perplexity.go +++ 
b/core/providers/perplexity/perplexity.go @@ -101,12 +101,12 @@ func (provider *PerplexityProvider) completeRequest(ctx *schemas.BifrostContext, // Handle error response if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))) - return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp, schemas.ChatCompletionRequest, provider.GetProviderKey(), model) + return nil, latency, providerResponseHeaders, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + return nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Read the response body and copy it before releasing the response @@ -141,8 +141,7 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToPerplexityChatCompletionRequest(request), nil - }, - provider.GetProviderKey()) + }) if err != nil { return nil, err } @@ -161,9 +160,6 @@ func (provider *PerplexityProvider) ChatCompletion(ctx *schemas.BifrostContext, bifrostResponse := response.ToBifrostChatResponse(request.Model) // Set ExtraFields - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders @@ -223,9 +219,6 @@ func (provider *PerplexityProvider) Responses(ctx *schemas.BifrostContext, key s } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = 
schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } diff --git a/core/providers/perplexity/types.go b/core/providers/perplexity/types.go index feef9e0ccb..d5ad5c65f6 100644 --- a/core/providers/perplexity/types.go +++ b/core/providers/perplexity/types.go @@ -4,45 +4,45 @@ import "github.com/maximhq/bifrost/core/schemas" // PerplexityChatRequest represents a Perplexity chat completion request type PerplexityChatRequest struct { - Model string `json:"model"` // Required: Model to use for chat completion - Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects - SearchMode *string `json:"search_mode"` // Required: Search mode - ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) - MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate - Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature - TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling - LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference - SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter - ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images - ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions - SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter - SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter - SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter - LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter - LastUpdatedBeforeFilter 
*string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter - TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling - Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty - ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response - DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search - EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier - WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options - MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response - Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model - ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls - Stop []string `json:"stop,omitempty"` // Optional: Stop sequences - LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities - TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities - NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results - NumImages *int `json:"num_images,omitempty"` // Optional: Number of images - SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter - ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter - ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter - SafeSearch *bool `json:"safe_search,omitempty"` 
// Optional: Enable safe search - StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode - ExtraParams map[string]interface{} `json:"-"` + Model string `json:"model"` // Required: Model to use for chat completion + Messages []schemas.ChatMessage `json:"messages"` // Required: Array of message objects + SearchMode *string `json:"search_mode"` // Required: Search mode + ReasoningEffort *string `json:"reasoning_effort"` // Required: Reasoning effort (low, medium, high) + MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature + TopP *float64 `json:"top_p,omitempty"` // Optional: Top-p sampling + LanguagePreference *string `json:"language_preference,omitempty"` // Optional: Language preference + SearchDomainFilter []string `json:"search_domain_filter,omitempty"` // Optional: Search domain filter + ReturnImages *bool `json:"return_images,omitempty"` // Optional: Return images + ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"` // Optional: Return related questions + SearchRecencyFilter *string `json:"search_recency_filter,omitempty"` // Optional: Search recency filter + SearchAfterDateFilter *string `json:"search_after_date_filter,omitempty"` // Optional: Search after date filter + SearchBeforeDateFilter *string `json:"search_before_date_filter,omitempty"` // Optional: Search before date filter + LastUpdatedAfterFilter *string `json:"last_updated_after_filter,omitempty"` // Optional: Last updated after filter + LastUpdatedBeforeFilter *string `json:"last_updated_before_filter,omitempty"` // Optional: Last updated before filter + TopK *int `json:"top_k,omitempty"` // Optional: Top-k sampling + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty + FrequencyPenalty *float64 
`json:"frequency_penalty,omitempty"` // Optional: Frequency penalty + ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response + DisableSearch *bool `json:"disable_search,omitempty"` // Optional: Disable search + EnableSearchClassifier *bool `json:"enable_search_classifier,omitempty"` // Optional: Enable search classifier + WebSearchOptions []WebSearchOption `json:"web_search_options,omitempty"` // Optional: Web search options + MediaResponse *MediaResponse `json:"media_response,omitempty"` // Optional: Media response + Tools []schemas.ChatTool `json:"tools,omitempty"` // Optional: Tools available for the model + ToolChoice *schemas.ChatToolChoice `json:"tool_choice,omitempty"` // Optional: Whether to call a tool + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Optional: Enable parallel tool calls + Stop []string `json:"stop,omitempty"` // Optional: Stop sequences + LogProbs *bool `json:"logprobs,omitempty"` // Optional: Return log probabilities + TopLogProbs *int `json:"top_logprobs,omitempty"` // Optional: Number of top log probabilities + NumSearchResults *int `json:"num_search_results,omitempty"` // Optional: Number of search results + NumImages *int `json:"num_images,omitempty"` // Optional: Number of images + SearchLanguageFilter []string `json:"search_language_filter,omitempty"` // Optional: Search language filter + ImageFormatFilter []string `json:"image_format_filter,omitempty"` // Optional: Image format filter + ImageDomainFilter []string `json:"image_domain_filter,omitempty"` // Optional: Image domain filter + SafeSearch *bool `json:"safe_search,omitempty"` // Optional: Enable safe search + StreamMode *string `json:"stream_mode,omitempty"` // Optional: Stream mode + ExtraParams map[string]interface{} `json:"-"` } // GetExtraParams implements the RequestBodyWithExtraParams interface diff --git a/core/providers/replicate/errors.go b/core/providers/replicate/errors.go index e7fc2051d0..1575d9ca77 100644 
--- a/core/providers/replicate/errors.go +++ b/core/providers/replicate/errors.go @@ -15,9 +15,6 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: replicateErr.Detail, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -28,8 +25,5 @@ func parseReplicateError(body []byte, statusCode int) *schemas.BifrostError { Error: &schemas.ErrorField{ Message: string(body), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/replicate/files.go b/core/providers/replicate/files.go index cdd37a65c5..15ca0e13e8 100644 --- a/core/providers/replicate/files.go +++ b/core/providers/replicate/files.go @@ -30,8 +30,6 @@ func (r *ReplicateFileResponse) ToBifrostFileUploadResponse(providerName schemas Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileUploadRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } @@ -67,8 +65,6 @@ func (r *ReplicateFileResponse) ToBifrostFileRetrieveResponse(providerName schem Status: ToBifrostFileStatus(r), StorageBackend: schemas.FileStorageAPI, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileRetrieveRequest, - Provider: providerName, Latency: latency.Milliseconds(), }, } diff --git a/core/providers/replicate/images.go b/core/providers/replicate/images.go index 4fa7d0dd81..0327b8b5fa 100644 --- a/core/providers/replicate/images.go +++ b/core/providers/replicate/images.go @@ -1,7 +1,6 @@ package replicate import ( - "strconv" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" @@ -28,49 +27,6 @@ var modelInputImageFieldMap = map[string]string{ "black-forest-labs/flux-krea-dev": "image", } -// convertSizeToReplicateFormat converts standard size format (e.g., "1024x1024") to Replicate format. 
-// Returns (aspectRatio, imageSize) where imageSize is "1k", "2k", "4k" and aspectRatio is one of: -// "1:1", "3:4", "4:3", "9:16", or "16:9". Returns empty strings if unparseable or ratio unrecognised. -func convertSizeToReplicateFormat(size string) (aspectRatio, imageSize string) { - parts := strings.Split(size, "x") - if len(parts) != 2 { - return "", "" - } - - width, err1 := strconv.Atoi(parts[0]) - height, err2 := strconv.Atoi(parts[1]) - if err1 != nil || err2 != nil { - return "", "" - } - - if width <= 0 || height <= 0 { - return "", "" - } - - if width <= 1024 && height <= 1024 { - imageSize = "1K" - } else if width <= 2048 && height <= 2048 { - imageSize = "2K" - } else if width <= 4096 && height <= 4096 { - imageSize = "4K" - } - - ratio := float64(width) / float64(height) - if ratio >= 0.99 && ratio <= 1.01 { - aspectRatio = "1:1" - } else if ratio >= 0.74 && ratio <= 0.76 { - aspectRatio = "3:4" - } else if ratio >= 1.32 && ratio <= 1.34 { - aspectRatio = "4:3" - } else if ratio >= 0.56 && ratio <= 0.57 { - aspectRatio = "9:16" - } else if ratio >= 1.77 && ratio <= 1.78 { - aspectRatio = "16:9" - } - - return aspectRatio, imageSize -} - // ToReplicateImageGenerationInput converts a Bifrost image generation request to Replicate prediction input func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationRequest) *ReplicatePredictionRequest { if bifrostReq == nil || bifrostReq.Input == nil { @@ -85,29 +41,6 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR if bifrostReq.Params != nil { params := bifrostReq.Params - // Map InputImages to the appropriate field based on model - if len(params.InputImages) > 0 { - fieldName := getInputImageFieldName(bifrostReq.Model) - - switch fieldName { - case "image_prompt": - // For flux-1.1-pro variants: use first image as image_prompt - input.ImagePrompt = &params.InputImages[0] - - case "input_image": - // For flux-kontext variants: add to ExtraParams as input_image - 
input.InputImage = &params.InputImages[0] - - case "image": - // For flux-dev variants: use first image as image field - input.Image = &params.InputImages[0] - - case "input_images": - // For all other models: use input_images array - input.InputImages = params.InputImages - } - } - if bifrostReq.Params.N != nil { input.NumberOfImages = bifrostReq.Params.N } @@ -117,7 +50,7 @@ func ToReplicateImageGenerationInput(bifrostReq *schemas.BifrostImageGenerationR } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitResolution := params.ExtraParams["resolution"] if params.AspectRatio == nil && aspectRatio != "" { input.AspectRatio = &aspectRatio @@ -191,9 +124,6 @@ Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } @@ -294,7 +224,7 @@ func ToReplicateImageEditInput(bifrostReq *schemas.BifrostImageEditRequest) *Rep input.Image = &images[0] case "input_images": - // For all other models: use input_images array + // For all other models: use input_images array (preserves multi-image support) input.InputImages = images } } @@ -309,7 +239,7 @@ } if params.Size != nil { - aspectRatio, imageSize := convertSizeToReplicateFormat(*params.Size) + aspectRatio, imageSize := providerUtils.ConvertSizeToAspectRatioAndResolution(*params.Size) _, hasExplicitAspectRatio := params.ExtraParams["aspect_ratio"] _, hasExplicitResolution := params.ExtraParams["resolution"] if aspectRatio != "" && !hasExplicitAspectRatio { diff --git a/core/providers/replicate/models.go b/core/providers/replicate/models.go index 3989628db1..6c0c14dbf7 100644 --- a/core/providers/replicate/models.go +++ b/core/providers/replicate/models.go @@ -1,61 
+1,66 @@ package replicate import ( - "slices" "strings" + providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) -// ToBifrostListModelsResponse converts Replicate models and deployments to a Bifrost list models response +// ToBifrostListModelsResponse converts Replicate deployments to a Bifrost list models response. +// Replicate model IDs are composite: "{owner}/{name}" (e.g. "stability-ai/stable-diffusion"). func ToBifrostListModelsResponse( deploymentsResponse *ReplicateDeploymentListResponse, providerKey schemas.ModelProvider, - allowedModels []string, - blacklistedModels []string, + allowedModels schemas.WhiteList, + blacklistedModels schemas.BlackList, + aliases map[string]string, unfiltered bool, ) *schemas.BifrostListModelsResponse { bifrostResponse := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } - includedModels := make(map[string]bool) - // Add deployments from /v1/deployments endpoint + pipeline := &providerUtils.ListModelsPipeline{ + AllowedModels: allowedModels, + BlacklistedModels: blacklistedModels, + Aliases: aliases, + Unfiltered: unfiltered, + ProviderKey: providerKey, + MatchFns: providerUtils.DefaultMatchFns(), + } + if pipeline.ShouldEarlyExit() { + return bifrostResponse + } + + included := make(map[string]bool) + if deploymentsResponse != nil { for _, deployment := range deploymentsResponse.Results { + // Replicate model IDs are composite owner/name deploymentID := deployment.Owner + "/" + deployment.Name - modelName := schemas.Ptr(deployment.Name) var created *int64 - - if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, deploymentID) { - continue - } - if !unfiltered && slices.Contains(blacklistedModels, deploymentID) { - continue - } - - // Extract information from current release if available - if deployment.CurrentRelease != nil { - // Parse created timestamp - if deployment.CurrentRelease.CreatedAt != "" { - createdTimestamp := 
ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) - if createdTimestamp > 0 { - created = schemas.Ptr(createdTimestamp) - } + if deployment.CurrentRelease != nil && deployment.CurrentRelease.CreatedAt != "" { + createdTimestamp := ParseReplicateTimestamp(deployment.CurrentRelease.CreatedAt) + if createdTimestamp > 0 { + created = schemas.Ptr(createdTimestamp) } } - bifrostModel := schemas.Model{ - ID: string(providerKey) + "/" + deploymentID, - Name: modelName, - Deployment: modelName, - OwnedBy: schemas.Ptr(deployment.Owner), - Created: created, + for _, result := range pipeline.FilterModel(deploymentID) { + bifrostModel := schemas.Model{ + ID: string(providerKey) + "/" + result.ResolvedID, + Name: schemas.Ptr(deployment.Name), + OwnedBy: schemas.Ptr(deployment.Owner), + Created: created, + } + if result.AliasValue != "" { + bifrostModel.Alias = schemas.Ptr(result.AliasValue) + } + bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) + included[strings.ToLower(result.ResolvedID)] = true } - - bifrostResponse.Data = append(bifrostResponse.Data, bifrostModel) - includedModels[deploymentID] = true } if deploymentsResponse.Next != nil { @@ -63,58 +68,8 @@ func ToBifrostListModelsResponse( } } - // Backfill allowed models that were not in the response - if !unfiltered && len(allowedModels) > 0 { - for _, allowedModel := range allowedModels { - if slices.Contains(blacklistedModels, allowedModel) { - continue - } - if !includedModels[allowedModel] { - bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{ - ID: string(providerKey) + "/" + allowedModel, - Name: schemas.Ptr(allowedModel), - }) - } - } - } + bifrostResponse.Data = append(bifrostResponse.Data, + pipeline.BackfillModels(included)...) 
return bifrostResponse } - -// ToReplicateListModelsResponse converts a Bifrost list models response to a Replicate list models response -// This is mainly used for testing and compatibility -func ToReplicateListModelsResponse(response *schemas.BifrostListModelsResponse) *ReplicateModelListResponse { - if response == nil { - return nil - } - - replicateResponse := &ReplicateModelListResponse{ - Results: make([]ReplicateModelResponse, 0, len(response.Data)), - } - - for _, model := range response.Data { - modelID := strings.TrimPrefix(model.ID, string(schemas.Replicate)+"/") - replicateModel := ReplicateModelResponse{ - URL: "https://replicate.com/" + modelID, - Name: modelID, - } - - if model.Description != nil { - replicateModel.Description = model.Description - } - - if model.OwnedBy != nil { - replicateModel.Owner = *model.OwnedBy - } - - replicateResponse.Results = append(replicateResponse.Results, replicateModel) - } - - // Set next page token if available - if response.NextPageToken != "" { - next := response.NextPageToken - replicateResponse.Next = &next - } - - return replicateResponse -} diff --git a/core/providers/replicate/replicate.go b/core/providers/replicate/replicate.go index 373dbb5bbf..aedf38471c 100644 --- a/core/providers/replicate/replicate.go +++ b/core/providers/replicate/replicate.go @@ -87,6 +87,12 @@ const ( pollingInterval = 2 * time.Second ) +// useDeploymentsEndpoint returns whether the key uses the deployments endpoint. +// Nil ReplicateKeyConfig is treated as false (default models/predictions behavior). 
+func useDeploymentsEndpoint(key schemas.Key) bool { + return key.ReplicateKeyConfig != nil && key.ReplicateKeyConfig.UseDeploymentsEndpoint +} + // createPrediction creates a new prediction on Replicate API // Supports both sync (with Prefer: wait header) and async modes // stripPrefer should be true for streaming requests to exclude the Prefer header @@ -149,7 +155,7 @@ func createPrediction( // Parse response body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, schemas.Replicate) + return nil, nil, latency, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var prediction ReplicatePredictionResponse @@ -204,7 +210,7 @@ func getPrediction( // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, schemas.Replicate) + return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } prediction := &ReplicatePredictionResponse{} @@ -252,9 +258,7 @@ func pollPrediction( case <-pollCtx.Done(): return nil, nil, providerResponseHeaders, providerUtils.NewBifrostOperationError( schemas.ErrProviderRequestTimedOut, - fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds), - schemas.Replicate, - ) + fmt.Errorf("prediction polling timed out after %d seconds", timeoutSeconds)) case <-ticker.C: prediction, rawResponse, providerResponseHeaders, err = getPrediction(pollCtx, client, predictionURL, key, logger, sendBackRawResponse) if err != nil { @@ -277,6 +281,17 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont client := provider.client extraHeaders := provider.networkConfig.ExtraHeaders + if 
!useDeploymentsEndpoint(key) { + return ToBifrostListModelsResponse( + &ReplicateDeploymentListResponse{}, + providerName, + key.Models, + key.BlacklistedModels, + key.Aliases, + request.Unfiltered, + ), nil + } + // Build deployments URL deploymentsURL := provider.buildRequestURL(ctx, "/v1/deployments", schemas.ListModelsRequest) @@ -335,9 +350,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont if err := sonic.Unmarshal(bodyCopy, &pageResponse); err != nil { return nil, providerUtils.NewBifrostOperationError( "failed to parse deployments response", - err, - schemas.Replicate, - ) + err) } // Append results from this page @@ -362,6 +375,7 @@ func (provider *ReplicateProvider) listDeploymentsByKey(ctx *schemas.BifrostCont providerName, key.Models, key.BlacklistedModels, + key.Aliases, request.Unfiltered, ) @@ -375,11 +389,10 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys } if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } startTime := time.Now() - providerName := provider.GetProviderKey() response, err := providerUtils.HandleMultipleListModelsRequests( ctx, @@ -393,8 +406,6 @@ func (provider *ReplicateProvider) ListModels(ctx *schemas.BifrostContext, keys // Update metadata with total latency latency := time.Since(startTime) - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.ListModelsRequest response.ExtraFields.Latency = latency.Milliseconds() return response, nil @@ -406,17 +417,11 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := 
providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateTextRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -431,7 +436,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.TextCompletionRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // create prediction @@ -480,10 +485,7 @@ func (provider *ReplicateProvider) TextCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostTextCompletionResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() - bifrostResponse.ExtraFields.ModelRequested = request.Model bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData) @@ -503,11 +505,6 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // Convert Bifrost request to Replicate format with streaming enabled jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -519,8 +516,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } replicateReq.Stream = schemas.Ptr(true) return replicateReq, nil - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -532,7 +528,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx 
*schemas.BifrostCont request.Model, provider.customProviderConfig, schemas.TextCompletionStreamRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // Create prediction @@ -556,9 +552,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" { bifrostErr := providerUtils.NewBifrostOperationError( "stream URL not available in prediction response", - fmt.Errorf("prediction response missing stream URL"), - provider.GetProviderKey(), - ) + fmt.Errorf("prediction response missing stream URL")) return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -587,11 +581,12 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.TextCompletionStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger) } close(responseChan) }() @@ -636,7 +631,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) provider.logger.Warn("Error reading stream: %v", readErr) - enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, 
provider.sendBackRawRequest, provider.sendBackRawResponse) + enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) } break @@ -667,11 +662,8 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), }, } @@ -705,14 +697,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont case "canceled": bifrostErr := providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("stream ended: prediction canceled"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended: prediction canceled")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -727,14 +712,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont } bifrostErr := providerUtils.NewBifrostOperationError( errorMsg, - fmt.Errorf("stream ended with error"), - provider.GetProviderKey(), - ) - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: provider.GetProviderKey(), - 
ModelRequested: request.Model, - RequestType: schemas.TextCompletionStreamRequest, - } + fmt.Errorf("stream ended with error")) ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger) @@ -750,10 +728,7 @@ func (provider *ReplicateProvider) TextCompletionStream(ctx *schemas.BifrostCont nil, // usage - not available in done event finishReason, chunkIndex, - schemas.TextCompletionStreamRequest, - provider.GetProviderKey(), - request.Model, - ) + schemas.TextCompletionStreamRequest) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -781,17 +756,11 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k return nil, err } - deployment, isDeployment := resolveDeploymentModel(request.Model, key) - if isDeployment { - request.Model = deployment - } - // build replicate request jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, - func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }, - provider.GetProviderKey()) + func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateChatRequest(request) }) if bifrostErr != nil { return nil, bifrostErr } @@ -806,7 +775,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k request.Model, provider.customProviderConfig, schemas.ChatCompletionRequest, - isDeployment, + useDeploymentsEndpoint(key), ) // create prediction @@ -855,10 +824,7 @@ func (provider *ReplicateProvider) ChatCompletion(ctx *schemas.BifrostContext, k bifrostResponse := prediction.ToBifrostChatResponse() // Set extra fields - bifrostResponse.ExtraFields.Provider = schemas.Replicate - 
-	bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
@@ -878,11 +844,6 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Convert Bifrost request to Replicate format with streaming enabled
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
@@ -894,8 +855,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 			}
 			replicateReq.Stream = schemas.Ptr(true)
 			return replicateReq, nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -907,7 +867,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ChatCompletionStreamRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction
@@ -931,9 +891,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 	if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
 		bifrostErr := providerUtils.NewBifrostOperationError(
 			"stream URL not available in prediction response",
-			fmt.Errorf("prediction response missing stream URL"),
-			provider.GetProviderKey(),
-		)
+			fmt.Errorf("prediction response missing stream URL"))
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -962,11 +920,12 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ChatCompletionStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1011,7 +970,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", readErr)
-				enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+				enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger)
 			}
 			break
@@ -1049,11 +1008,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 				},
 			},
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ChatCompletionStreamRequest,
-				Provider:       provider.GetProviderKey(),
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(lastChunkTime).Milliseconds(),
 			},
 		}
@@ -1087,14 +1043,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 			case "canceled":
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"prediction was canceled",
-					fmt.Errorf("stream ended: prediction canceled"),
-					provider.GetProviderKey(),
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					RequestType:    schemas.ChatCompletionStreamRequest,
-				}
+					fmt.Errorf("stream ended: prediction canceled"))
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger)
@@ -1109,14 +1058,7 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 				}
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					errorMsg,
-					fmt.Errorf("stream ended with error"),
-					provider.GetProviderKey(),
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					RequestType:    schemas.ChatCompletionStreamRequest,
-				}
+					fmt.Errorf("stream ended with error"))
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				enrichedErr := providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger)
@@ -1142,11 +1084,8 @@ func (provider *ReplicateProvider) ChatCompletionStream(ctx *schemas.BifrostCont
 				},
 			},
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ChatCompletionStreamRequest,
-				Provider:       provider.GetProviderKey(),
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
@@ -1174,17 +1113,11 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// build replicate request
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
-		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) },
-		provider.GetProviderKey())
+		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) })
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1199,7 +1132,7 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ResponsesRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// create prediction
@@ -1246,9 +1179,6 @@ func (provider *ReplicateProvider) Responses(ctx *schemas.BifrostContext, key sc
 	// Convert to Bifrost response
 	response := prediction.ToBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
@@ -1266,24 +1196,18 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Build replicate request
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
-		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) },
-		provider.GetProviderKey())
+		func() (providerUtils.RequestBodyWithExtraParams, error) { return ToReplicateResponsesRequest(request) })
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	// Enable streaming (using sjson to set field directly, preserving key order)
 	if updatedData, err := providerUtils.SetJSONField(jsonData, "stream", true); err != nil {
-		return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err, provider.GetProviderKey())
+		return nil, providerUtils.NewBifrostOperationError("failed to set stream field", err)
 	} else {
 		jsonData = updatedData
 	}
@@ -1295,7 +1219,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ResponsesStreamRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction
@@ -1319,9 +1243,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 	if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
 		bifrostErr := providerUtils.NewBifrostOperationError(
 			"stream URL not available in prediction response",
-			fmt.Errorf("prediction response missing stream URL"),
-			provider.GetProviderKey(),
-		)
+			fmt.Errorf("prediction response missing stream URL"))
 		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1360,9 +1282,9 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 			}, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 		if errors.Is(streamErr, fasthttp.ErrTimeout) || errors.Is(streamErr, context.DeadlineExceeded) {
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr, provider.GetProviderKey()), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, streamErr), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Extract provider response headers before status check so error responses also forward them
@@ -1389,11 +1311,15 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 	// Start streaming in a goroutine
 	go func() {
+		// Registered first so the post-hook span finalizer runs on every exit
+		// path — including the empty-reader early return below, which would
+		// otherwise skip any finalizer declared later in this goroutine.
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.GetProviderKey(), request.Model, schemas.ResponsesStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -1405,10 +1331,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 
 		if reader == nil {
 			bifrostErr := providerUtils.NewBifrostOperationError(
-				"Provider returned an empty response",
-				fmt.Errorf("provider returned an empty response"),
-				provider.GetProviderKey(),
-			)
+				"provider returned an empty response",
+				fmt.Errorf("provider returned an empty response"))
 			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 			providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse), responseChan, provider.logger)
 			return
@@ -1455,7 +1379,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn("Error reading stream: %v", readErr)
-				bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr, provider.GetProviderKey())
+				bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr)
 
 				// Include accumulated raw responses in error
 				if sendBackRawResponse && len(rawResponseChunks) > 0 {
@@ -1497,11 +1421,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					CreatedAt: int(startTime.Unix()),
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					Latency:        time.Since(startTime).Milliseconds(),
-					ChunkIndex:     sequenceNumber,
+					Latency:    time.Since(startTime).Milliseconds(),
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			if sendBackRawRequest {
@@ -1524,10 +1445,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					CreatedAt: int(startTime.Unix()),
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1556,10 +1474,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					},
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1587,10 +1502,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					},
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1610,10 +1522,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 				Delta:    schemas.Ptr(currentEvent.Data),
 				LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1639,10 +1548,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 				ItemID:   schemas.Ptr(itemID),
 				LogProbs: []schemas.ResponsesOutputMessageContentTextLogProb{},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1665,10 +1571,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					},
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1702,10 +1605,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					},
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					ChunkIndex:     sequenceNumber,
+					ChunkIndex: sequenceNumber,
 				},
 			}
 			providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
@@ -1725,11 +1625,8 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 					CompletedAt: schemas.Ptr(int(time.Now().Unix())),
 				},
 				ExtraFields: schemas.BifrostResponseExtraFields{
-					RequestType:    schemas.ResponsesStreamRequest,
-					Provider:       provider.GetProviderKey(),
-					ModelRequested: request.Model,
-					Latency:        time.Since(startTime).Milliseconds(),
-					ChunkIndex:     sequenceNumber,
+					Latency:    time.Since(startTime).Milliseconds(),
+					ChunkIndex: sequenceNumber,
 				},
 			}
@@ -1762,14 +1659,7 @@ func (provider *ReplicateProvider) ResponsesStream(ctx *schemas.BifrostContext,
 			}
 			bifrostErr := providerUtils.NewBifrostOperationError(
 				errorMsg,
-				fmt.Errorf("stream error: %s", errorMsg),
-				provider.GetProviderKey(),
-			)
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       provider.GetProviderKey(),
-				ModelRequested: request.Model,
-				RequestType:    schemas.ResponsesStreamRequest,
-			}
+				fmt.Errorf("stream error: %s", errorMsg))
 
 			// Include accumulated raw responses in error
 			if sendBackRawResponse && len(rawResponseChunks) > 0 {
@@ -1830,19 +1720,13 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Convert Bifrost request to Replicate format
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToReplicateImageGenerationInput(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1857,7 +1741,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext,
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ImageGenerationRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction with appropriate mode
@@ -1909,10 +1793,7 @@ func (provider *ReplicateProvider) ImageGeneration(ctx *schemas.BifrostContext,
 	}
 
 	// Set extra fields
-	bifrostResponse.ExtraFields.Provider = schemas.Replicate
-	bifrostResponse.ExtraFields.RequestType = schemas.ImageGenerationRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
@@ -1931,15 +1812,9 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Convert Bifrost request to Replicate format with streaming enabled
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
@@ -1948,8 +1823,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 			replicateReq := ToReplicateImageGenerationInput(request)
 			replicateReq.Stream = schemas.Ptr(true)
 			return replicateReq, nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1961,7 +1835,7 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ImageGenerationStreamRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction
 	prediction, _, _, _, err := createPrediction(
@@ -1982,10 +1856,16 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 
 	// Verify stream URL is available
 	if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
-		return nil, providerUtils.NewBifrostOperationError(
-			"stream URL not available in prediction response",
-			fmt.Errorf("prediction response missing stream URL"),
-			providerName,
+		return nil, providerUtils.EnrichError(
+			ctx,
+			providerUtils.NewBifrostOperationError(
+				"stream URL not available in prediction response",
+				fmt.Errorf("prediction response missing stream URL"),
+			),
+			jsonData,
+			nil,
+			sendBackRawRequest,
+			sendBackRawResponse,
 		)
 	}
@@ -2014,11 +1894,12 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageGenerationStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -2065,7 +1946,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 				}
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				provider.logger.Warn(fmt.Sprintf("Error reading SSE stream: %v", readErr))
-				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.ImageGenerationStreamRequest, providerName, request.Model, provider.logger)
+				enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
+				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger)
 			}
 			break
 		}
@@ -2110,11 +1992,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 			CreatedAt:    time.Now().Unix(),
 			OutputFormat: outputFormat,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ImageGenerationStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(lastChunkTime).Milliseconds(),
 			},
 		}
@@ -2148,36 +2027,24 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 			case "canceled":
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"prediction was canceled",
-					fmt.Errorf("stream ended: prediction canceled"),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageGenerationStreamRequest,
-				}
+					fmt.Errorf("stream ended: prediction canceled"))
 				// Include accumulated raw responses in error
 				if sendBackRawResponse && len(rawResponseChunks) > 0 {
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
 			case "error":
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"prediction failed",
-					fmt.Errorf("stream ended with error"),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageGenerationStreamRequest,
-				}
+					fmt.Errorf("stream ended with error"))
 				// Include accumulated raw responses in error
 				if sendBackRawResponse && len(rawResponseChunks) > 0 {
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -2192,11 +2059,8 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 			OutputFormat: lastOutputFormat, // Include output format
 			CreatedAt:    time.Now().Unix(),
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ImageGenerationStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
@@ -2238,17 +2102,13 @@ func (provider *ReplicateProvider) ImageGenerationStream(ctx *schemas.BifrostCon
 					Error: &schemas.ErrorField{
 						Message: errorMsg,
 					},
-					ExtraFields: schemas.BifrostErrorExtraFields{
-						Provider:       providerName,
-						ModelRequested: request.Model,
-						RequestType:    schemas.ImageGenerationStreamRequest,
-					},
 				}
 				// Include accumulated raw responses in error
 				if sendBackRawResponse {
 					rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -2265,19 +2125,13 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Convert Bifrost request to Replicate format
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToReplicateImageEditInput(request), nil
-		},
-		provider.GetProviderKey())
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2292,7 +2146,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ImageEditRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction with appropriate mode
@@ -2344,10 +2198,7 @@ func (provider *ReplicateProvider) ImageEdit(ctx *schemas.BifrostContext, key sc
 	}
 
 	// Set extra fields
-	bifrostResponse.ExtraFields.Provider = schemas.Replicate
-	bifrostResponse.ExtraFields.RequestType = schemas.ImageEditRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
@@ -2366,15 +2217,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	providerName := provider.GetProviderKey()
 	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
 	// Convert Bifrost request to Replicate format with streaming enabled
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
@@ -2383,8 +2228,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 			replicateReq := ToReplicateImageEditInput(request)
 			replicateReq.Stream = schemas.Ptr(true)
 			return replicateReq, nil
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2396,7 +2240,7 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 		request.Model,
 		provider.customProviderConfig,
 		schemas.ImageEditStreamRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction
@@ -2418,10 +2262,16 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 
 	// Verify stream URL is available
 	if prediction.URLs == nil || prediction.URLs.Stream == nil || *prediction.URLs.Stream == "" {
-		return nil, providerUtils.NewBifrostOperationError(
-			"stream URL not available in prediction response",
-			fmt.Errorf("prediction response missing stream URL"),
-			providerName,
+		return nil, providerUtils.EnrichError(
+			ctx,
+			providerUtils.NewBifrostOperationError(
+				"stream URL not available in prediction response",
+				fmt.Errorf("prediction response missing stream URL"),
+			),
+			jsonData,
+			nil,
+			sendBackRawRequest,
+			sendBackRawResponse,
 		)
 	}
@@ -2450,11 +2300,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 	// Start streaming in a goroutine
 	go func() {
+		defer providerUtils.EnsureStreamFinalizerCalled(ctx)
 		defer func() {
 			if ctx.Err() == context.Canceled {
-				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger)
+				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger)
 			} else if ctx.Err() == context.DeadlineExceeded {
-				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.ImageEditStreamRequest, provider.logger)
+				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger)
 			}
 			close(responseChan)
 		}()
@@ -2499,18 +2350,9 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 				if errors.Is(readErr, context.Canceled) {
 					return
 				}
-				bifrostErr := providerUtils.NewBifrostOperationError(
-					"stream read error",
-					readErr,
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditStreamRequest,
-				}
+				enrichedErr := providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("stream read error", readErr), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
-				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
+				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, enrichedErr, responseChan, provider.logger)
 			}
 			break
 		}
@@ -2553,11 +2395,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 			CreatedAt:    time.Now().Unix(),
 			OutputFormat: outputFormat,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ImageEditStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(lastChunkTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(lastChunkTime).Milliseconds(),
 			},
 		}
@@ -2591,34 +2430,22 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 			case "canceled":
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"prediction was canceled",
-					fmt.Errorf("stream ended: prediction canceled"),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditStreamRequest,
-				}
+					fmt.Errorf("stream ended: prediction canceled"))
 				if sendBackRawResponse && len(rawResponseChunks) > 0 {
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
 			case "error":
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"prediction failed",
-					fmt.Errorf("stream ended with error"),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditStreamRequest,
-				}
+					fmt.Errorf("stream ended with error"))
 				if sendBackRawResponse && len(rawResponseChunks) > 0 {
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -2633,11 +2460,8 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 			CreatedAt:    time.Now().Unix(),
 			OutputFormat: lastOutputFormat,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				RequestType:    schemas.ImageEditStreamRequest,
-				Provider:       providerName,
-				ModelRequested: request.Model,
-				ChunkIndex:     chunkIndex,
-				Latency:        time.Since(startTime).Milliseconds(),
+				ChunkIndex: chunkIndex,
+				Latency:    time.Since(startTime).Milliseconds(),
 			},
 		}
@@ -2665,18 +2489,12 @@ func (provider *ReplicateProvider) ImageEditStream(ctx *schemas.BifrostContext,
 
 				bifrostErr := providerUtils.NewBifrostOperationError(
 					"stream error",
-					fmt.Errorf("%s", errorData.Detail),
-					providerName,
-				)
-				bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.ImageEditStreamRequest,
-				}
+					fmt.Errorf("%s", errorData.Detail))
 				if sendBackRawResponse {
 					rawResponseChunks = append(rawResponseChunks, ReplicateSSEEvent{Event: eventType, Data: eventData})
 					bifrostErr.ExtraFields.RawResponse = rawResponseChunks
 				}
+				bifrostErr = providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
 				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
 				providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, provider.logger)
 				return
@@ -2698,21 +2516,13 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 		return nil, err
 	}
 
-	deployment, isDeployment := resolveDeploymentModel(request.Model, key)
-	if isDeployment {
-		request.Model = deployment
-	}
-
-	providerName := provider.GetProviderKey()
-
 	// Convert Bifrost request to Replicate format
 	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToReplicateVideoGenerationInput(request)
-		},
-		providerName)
+		})
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -2724,7 +2534,7 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 		request.Model,
 		provider.customProviderConfig,
 		schemas.VideoGenerationRequest,
-		isDeployment,
+		useDeploymentsEndpoint(key),
 	)
 
 	// Create prediction with appropriate mode
@@ -2753,13 +2563,10 @@ func (provider *ReplicateProvider) VideoGeneration(ctx *schemas.BifrostContext,
 	if err != nil {
 		return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
-	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName)
+	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, schemas.Replicate)
 
 	// Set extra fields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.RequestType = schemas.VideoGenerationRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
@@ -2779,7 +2586,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	providerName := provider.GetProviderKey()
 
 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 	videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
@@ -2821,7 +2628,7 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 
 	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
@@ -2833,12 +2640,10 @@ func (provider *ReplicateProvider) VideoRetrieve(ctx *schemas.BifrostContext, ke
 	bifrostResponse, convertErr := ToBifrostVideoGenerationResponse(&prediction)
 	if convertErr != nil {
-		return nil, providerUtils.EnrichError(ctx, convertErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, convertErr, nil, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 	bifrostResponse.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResponse.ID, providerName)
 
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.RequestType = schemas.VideoRetrieveRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
 	if sendBackRawResponse {
@@ -2853,9 +2658,8 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if err := providerUtils.CheckOperationAllowed(schemas.Replicate, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
 		return nil, err
 	}
-	providerName := provider.GetProviderKey()
 	if request.ID == "" {
-		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
 	}
 
 	// Retrieve latest status/output first.
 	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
@@ -2869,19 +2673,17 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if videoResp.Status != schemas.VideoStatusCompleted {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
-			nil,
-			providerName,
-		)
+			nil)
 	}
 	if len(videoResp.Videos) == 0 {
-		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
 	}
 	var videoUrl string
 	if videoResp.Videos[0].URL != nil {
 		videoUrl = *videoResp.Videos[0].URL
 	}
 	if videoUrl == "" {
-		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName)
+		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
 	}
 	req := fasthttp.AcquireRequest()
 	resp := fasthttp.AcquireResponse()
@@ -2901,9 +2703,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	if resp.StatusCode() != fasthttp.StatusOK {
 		return nil, providerUtils.NewBifrostOperationError(
 			fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
-			nil,
-			providerName,
-		)
+			nil)
 	}
 
 	providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
@@ -2911,7 +2711,7 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	body, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
+		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 	}
 	contentType := string(resp.Header.ContentType())
 	if contentType == "" {
@@ -2925,8 +2725,6 @@ func (provider *ReplicateProvider) VideoDownload(ctx *schemas.BifrostContext, ke
 	}
 
 	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
-	bifrostResp.ExtraFields.Provider
= providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2982,7 +2780,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s providerName := provider.GetProviderKey() if len(request.File) == 0 { - return nil, providerUtils.NewBifrostOperationError("file content is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file content is required", nil) } // Create multipart form data @@ -3019,22 +2817,22 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s part, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create form file", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create form file", err) } if _, err := part.Write(request.File); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write file content", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write file content", err) } // Add filename field if provided if filename != "" { if err := writer.WriteField("filename", filename); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write filename field", err) } } // Add type field (content type) if err := writer.WriteField("type", contentType); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write type field", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write type field", err) } // Add metadata field if provided @@ -3043,24 +2841,24 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s if len(metadata) > 0 { metadataJSON, err := providerUtils.MarshalSorted(metadata) if 
err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal metadata", err) } h := make(textproto.MIMEHeader) h.Set("Content-Disposition", `form-data; name="metadata"`) h.Set("Content-Type", "application/json") metadataPart, err := writer.CreatePart(h) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to create metadata part", err) } if _, err := metadataPart.Write(metadataJSON); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to write metadata", err) } } } } if err := writer.Close(); err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err) } // Create request @@ -3096,7 +2894,7 @@ func (provider *ReplicateProvider) FileUpload(ctx *schemas.BifrostContext, key s body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var replicateResp ReplicateFileResponse @@ -3124,7 +2922,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] // Initialize serial pagination helper (Replicate uses cursor-based pagination) helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger) if err != nil { - return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err) } 
// Get current key to query @@ -3135,10 +2933,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Object: "list", Data: []schemas.FileObject{}, HasMore: false, - ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, - }, }, nil } @@ -3187,7 +2981,7 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] body, decodeErr := providerUtils.CheckAndDecodeBody(resp) if decodeErr != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr) } var replicateResp ReplicateFileListResponse @@ -3231,8 +3025,6 @@ func (provider *ReplicateProvider) FileList(ctx *schemas.BifrostContext, keys [] Data: files, HasMore: finalHasMore, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileListRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3249,7 +3041,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3294,7 +3086,7 @@ func (provider *ReplicateProvider) FileRetrieve(ctx *schemas.BifrostContext, key if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3326,7 +3118,7 @@ func (provider *ReplicateProvider) 
FileDelete(ctx *schemas.BifrostContext, keys providerName := provider.GetProviderKey() if request.FileID == "" { - return nil, providerUtils.NewBifrostOperationError("file_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("file_id is required", nil) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -3369,8 +3161,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, @@ -3391,7 +3181,7 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys if err != nil { fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) - lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) continue } @@ -3416,8 +3206,6 @@ func (provider *ReplicateProvider) FileDelete(ctx *schemas.BifrostContext, keys Object: "file", Deleted: true, ExtraFields: schemas.BifrostResponseExtraFields{ - RequestType: schemas.FileDeleteRequest, - Provider: providerName, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerResponseHeaders, }, diff --git a/core/providers/replicate/replicate_test.go b/core/providers/replicate/replicate_test.go index c6f72cfded..a9179f963c 100644 --- a/core/providers/replicate/replicate_test.go +++ b/core/providers/replicate/replicate_test.go @@ -459,187 +459,6 @@ func TestBifrostToReplicateImageGenerationConversion(t *testing.T) { validate func(t *testing.T, result *replicate.ReplicatePredictionRequest) wantErr bool }{ - { - name: "Flux_1_1_Pro_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro", - Input: 
&schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Flux 1.1 Pro should use ImagePrompt field - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.Image) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_1_1_Pro_Ultra_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro-ultra", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, - { - name: "Flux_Pro_ImagePrompt", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, - { - name: "Flux_Kontext_Pro_InputImage", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-kontext-pro", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: 
&schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Kontext models should use InputImage field - assert.NotNil(t, result.Input.InputImage) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.Image) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_Kontext_Max_InputImage", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-kontext-max", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.InputImage) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.InputImage) - }, - }, - { - name: "Flux_Dev_Image", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-dev", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Flux Dev should use Image field - assert.NotNil(t, result.Input.Image) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.InputImages) - }, - }, - { - name: "Flux_Fill_Pro_Image", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-fill-pro", - Input: 
&schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - assert.NotNil(t, result.Input.Image) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.Image) - }, - }, - { - name: "Other_Model_InputImages", - input: &schemas.BifrostImageGenerationRequest{ - Model: "stability-ai/sdxl", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input1.jpg", "https://example.com/input2.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Other models should use InputImages array - assert.NotNil(t, result.Input.InputImages) - assert.Len(t, result.Input.InputImages, 2) - assert.Equal(t, "https://example.com/input1.jpg", result.Input.InputImages[0]) - assert.Equal(t, "https://example.com/input2.jpg", result.Input.InputImages[1]) - assert.Nil(t, result.Input.ImagePrompt) - assert.Nil(t, result.Input.InputImage) - assert.Nil(t, result.Input.Image) - }, - }, - { - name: "Model_With_Version", - input: &schemas.BifrostImageGenerationRequest{ - Model: "black-forest-labs/flux-1.1-pro:v1.0", - Input: &schemas.ImageGenerationInput{ - Prompt: prompt, - }, - Params: &schemas.ImageGenerationParameters{ - InputImages: []string{"https://example.com/input.jpg"}, - }, - }, - validate: func(t *testing.T, result *replicate.ReplicatePredictionRequest) { - require.NotNil(t, result) - require.NotNil(t, result.Input) - // Should still match flux-1.1-pro and use ImagePrompt - assert.NotNil(t, result.Input.ImagePrompt) - assert.Equal(t, "https://example.com/input.jpg", *result.Input.ImagePrompt) - }, - }, { name: "AllParameters", input: 
&schemas.BifrostImageGenerationRequest{ diff --git a/core/providers/replicate/types.go b/core/providers/replicate/types.go index 98f84e613e..3ae88c0095 100644 --- a/core/providers/replicate/types.go +++ b/core/providers/replicate/types.go @@ -313,28 +313,28 @@ type ReplicatePredictionListResponse struct { // ReplicateModelResponse represents a model response type ReplicateModelResponse struct { - URL string `json:"url"` // Model API URL - Owner string `json:"owner"` // Owner username or org name - Name string `json:"name"` // Model name - Description *string `json:"description,omitempty"` // Model description - Visibility string `json:"visibility"` // "public" or "private" - GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL - PaperURL *string `json:"paper_url,omitempty"` // Research paper URL - LicenseURL *string `json:"license_url,omitempty"` // License URL - RunCount *int `json:"run_count,omitempty"` // Number of times run - CoverImageURL *string `json:"cover_image_url,omitempty"` // Cover image URL - DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) - LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details - FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details + URL string `json:"url"` // Model API URL + Owner string `json:"owner"` // Owner username or org name + Name string `json:"name"` // Model name + Description *string `json:"description,omitempty"` // Model description + Visibility string `json:"visibility"` // "public" or "private" + GithubURL *string `json:"github_url,omitempty"` // GitHub repository URL + PaperURL *string `json:"paper_url,omitempty"` // Research paper URL + LicenseURL *string `json:"license_url,omitempty"` // License URL + RunCount *int `json:"run_count,omitempty"` // Number of times run + CoverImageURL *string 
`json:"cover_image_url,omitempty"` // Cover image URL + DefaultExample *json.RawMessage `json:"default_example,omitempty"` // Default example prediction (json.RawMessage preserves key ordering) + LatestVersion *ReplicateModelVersion `json:"latest_version,omitempty"` // Latest version details + FeaturedVersion *ReplicateModelVersion `json:"featured_version,omitempty"` // Featured version details } // ReplicateModelVersion represents a model version type ReplicateModelVersion struct { - ID string `json:"id"` // Version ID - CreatedAt string `json:"created_at"` // ISO 8601 timestamp - CogVersion *string `json:"cog_version,omitempty"` // Cog version used - OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) - DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID + ID string `json:"id"` // Version ID + CreatedAt string `json:"created_at"` // ISO 8601 timestamp + CogVersion *string `json:"cog_version,omitempty"` // Cog version used + OpenAPISchema json.RawMessage `json:"openapi_schema,omitempty"` // OpenAPI schema for the model (json.RawMessage preserves key ordering) + DockerImageID *string `json:"docker_image_id,omitempty"` // Docker image ID } // ReplicateModelListResponse represents a paginated list of models diff --git a/core/providers/replicate/utils.go b/core/providers/replicate/utils.go index 3279b0a847..1d88337539 100644 --- a/core/providers/replicate/utils.go +++ b/core/providers/replicate/utils.go @@ -31,17 +31,13 @@ func checkForErrorStatus(prediction *ReplicatePredictionResponse) *schemas.Bifro } return providerUtils.NewBifrostOperationError( "prediction failed", - fmt.Errorf("%s", errorMsg), - schemas.Replicate, - ) + fmt.Errorf("%s", errorMsg)) } if prediction.Status == ReplicatePredictionStatusCanceled { return providerUtils.NewBifrostOperationError( "prediction was canceled", - fmt.Errorf("prediction was canceled"), - schemas.Replicate, - ) + 
fmt.Errorf("prediction was canceled")) } return nil @@ -126,9 +122,9 @@ func listenToReplicateStreamURL( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, schemas.Replicate) + return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Extract provider response headers before status check so error responses also forward them @@ -178,24 +174,12 @@ func isVersionID(s string) bool { return versionIDPattern.MatchString(s) } -// resolveDeploymentModel checks if the model maps to a deployment. -// Returns the resolved model and whether it is a deployment. -func resolveDeploymentModel(model string, key schemas.Key) (string, bool) { - if key.ReplicateKeyConfig == nil || key.ReplicateKeyConfig.Deployments == nil { - return model, false - } - if deployment, ok := key.ReplicateKeyConfig.Deployments[model]; ok && deployment != "" { - return deployment, true - } - return model, false -} - // buildPredictionURL builds the appropriate URL for creating a prediction // Returns the URL for the appropriate prediction endpoint. 
-func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, isDeployment bool) string { +func buildPredictionURL(ctx *schemas.BifrostContext, baseURL, model string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType, useDeploymentsEndpoint bool) string { var defaultPath string - if isDeployment { + if useDeploymentsEndpoint { defaultPath = "/v1/deployments/" + model + "/predictions" } else if isVersionID(model) { // If model is a version ID, use base predictions endpoint diff --git a/core/providers/replicate/videos.go b/core/providers/replicate/videos.go index b6dadaab55..3a277d067d 100644 --- a/core/providers/replicate/videos.go +++ b/core/providers/replicate/videos.go @@ -87,9 +87,6 @@ func ToBifrostVideoGenerationResponse(prediction *ReplicatePredictionResponse) ( Error: &schemas.ErrorField{ Message: "prediction response is nil", }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: schemas.Replicate, - }, } } diff --git a/core/providers/runway/errors.go b/core/providers/runway/errors.go index a64f8ffc60..d9259e825f 100644 --- a/core/providers/runway/errors.go +++ b/core/providers/runway/errors.go @@ -9,7 +9,7 @@ import ( ) // parseRunwayError parses Runway API error responses and converts them to BifrostError. 
-func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError { +func parseRunwayError(resp *fasthttp.Response) *schemas.BifrostError { // Parse as RunwayAPIError var errorResp RunwayAPIError bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp) @@ -34,12 +34,5 @@ func parseRunwayError(resp *fasthttp.Response, meta *providerUtils.RequestMetada bifrostErr.Error.Message = strings.TrimRight(bifrostErr.Error.Message, "\n") } - // Set metadata - if meta != nil { - bifrostErr.ExtraFields.Provider = meta.Provider - bifrostErr.ExtraFields.ModelRequested = meta.Model - bifrostErr.ExtraFields.RequestType = meta.RequestType - } - return bifrostErr } diff --git a/core/providers/runway/runway.go b/core/providers/runway/runway.go index 7bcce918c9..6bb5e31e32 100644 --- a/core/providers/runway/runway.go +++ b/core/providers/runway/runway.go @@ -170,8 +170,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostReq, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToRunwayVideoGenerationRequest(bifrostReq) - }, - provider.GetProviderKey()) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -210,17 +209,14 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), 
resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -237,10 +233,7 @@ func (provider *RunwayProvider) VideoGeneration(ctx *schemas.BifrostContext, key Object: "video", Status: schemas.VideoStatusQueued, ExtraFields: schemas.BifrostResponseExtraFields{ - Latency: latency.Milliseconds(), - Provider: providerName, - ModelRequested: model, - RequestType: schemas.VideoGenerationRequest, + Latency: latency.Milliseconds(), }, } @@ -287,16 +280,14 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Decode response body body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) 
+ return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -314,8 +305,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest if sendBackRawRequest { bifrostResp.ExtraFields.RawRequest = rawRequest @@ -329,7 +318,6 @@ func (provider *RunwayProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // VideoDownload retrieves a video from Runway's API. func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() // Retrieve task status to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ Provider: request.Provider, @@ -343,20 +331,21 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if taskDetails.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", taskDetails.Status), - nil, - providerName, - ) + nil) } if len(taskDetails.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var videoUrl string if taskDetails.Videos[0].URL != nil { videoUrl = *taskDetails.Videos[0].URL } if videoUrl == "" { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, 
providerUtils.NewBifrostOperationError("invalid video output type", nil) } + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + // Download video from Runway's URL req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() @@ -372,14 +361,13 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } // Get content and content type body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), nil, rawErrBody, sendBackRawRequest, sendBackRawResponse) } contentType := string(resp.Header.ContentType()) if contentType == "" { @@ -394,8 +382,6 @@ func (provider *RunwayProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest return bifrostResp, nil } @@ -407,7 +393,7 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch providerName := provider.GetProviderKey() if request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("task_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("task_id is required", nil) } taskID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName) @@ -439,10 +425,7 @@ func (provider *RunwayProvider) VideoDelete(ctx 
*schemas.BifrostContext, key sch // Handle error response - Runway returns 204 No Content on success if resp.StatusCode() != fasthttp.StatusNoContent { - return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoDeleteRequest, - }), nil, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseRunwayError(resp), nil, nil, sendBackRawRequest, sendBackRawResponse) } // Build response - Runway returns empty body on 204 @@ -453,8 +436,6 @@ func (provider *RunwayProvider) VideoDelete(ctx *schemas.BifrostContext, key sch } response.ExtraFields.Latency = latency.Milliseconds() - response.ExtraFields.Provider = providerName - response.ExtraFields.RequestType = schemas.VideoDeleteRequest return response, nil } diff --git a/core/providers/runway/videos.go b/core/providers/runway/videos.go index 49b0cec237..809a8a1038 100644 --- a/core/providers/runway/videos.go +++ b/core/providers/runway/videos.go @@ -121,7 +121,7 @@ func ToRunwayVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRe // ToBifrostVideoGenerationResponse converts Runway task details to Bifrost video generation response format. 
func ToBifrostVideoGenerationResponse(taskDetails *RunwayTaskDetailsResponse) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if taskDetails == nil { - return nil, providerUtils.NewBifrostOperationError("task details is nil", nil, schemas.Runway) + return nil, providerUtils.NewBifrostOperationError("task details is nil", nil) } response := &schemas.BifrostVideoGenerationResponse{ diff --git a/core/providers/sgl/sgl.go b/core/providers/sgl/sgl.go index ce5d3d936a..5b07356851 100644 --- a/core/providers/sgl/sgl.go +++ b/core/providers/sgl/sgl.go @@ -3,7 +3,6 @@ package sgl import ( - "fmt" "strings" "time" @@ -50,11 +49,7 @@ func NewSGLProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*SGL client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger) config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/") - // BaseURL is required for SGLang - if config.NetworkConfig.BaseURL == "" { - return nil, fmt.Errorf("base_url is required for sgl provider") - } - + // BaseURL is optional when keys have sgl_key_config with per-key URLs return &SGLProvider{ logger: logger, client: client, @@ -69,27 +64,40 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { return schemas.SGL } -// ListModels performs a list models request to SGL's API. -func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { - return openai.HandleOpenAIListModelsRequest( +// listModelsByKey performs a list models request for a single SGL key, +// resolving the per-key URL so each backend is queried individually. 
+func (provider *SGLProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return openai.ListModelsByKey( ctx, provider.client, - request, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"), - keys, + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/models"), + key, + request.Unfiltered, provider.networkConfig.ExtraHeaders, - schemas.SGL, + provider.GetProviderKey(), providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), ) } -// TextCompletion is not supported by the SGL provider. +// ListModels performs a list models request to SGL's API. +// Requests are made concurrently per key so that each backend is queried +// with its own URL (from sgl_key_config). +func (provider *SGLProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { + return providerUtils.HandleMultipleListModelsRequests( + ctx, + keys, + request, + provider.listModelsByKey, + ) +} + +// TextCompletion performs a text completion request to the SGL API. 
func (provider *SGLProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { return openai.HandleOpenAITextCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -109,7 +117,7 @@ func (provider *SGLProvider) TextCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAITextCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -129,7 +137,7 @@ func (provider *SGLProvider) ChatCompletion(ctx *schemas.BifrostContext, key sch return openai.HandleOpenAIChatCompletionRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, key, provider.networkConfig.ExtraHeaders, @@ -151,7 +159,7 @@ func (provider *SGLProvider) ChatCompletionStream(ctx *schemas.BifrostContext, p return openai.HandleOpenAIChatCompletionStreaming( ctx, provider.client, - provider.networkConfig.BaseURL+"/v1/chat/completions", + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"), request, nil, provider.networkConfig.ExtraHeaders, @@ -176,9 +184,6 @@ func (provider *SGLProvider) Responses(ctx *schemas.BifrostContext, key schemas. 
} response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -194,12 +199,12 @@ func (provider *SGLProvider) ResponsesStream(ctx *schemas.BifrostContext, postHo ) } -// Embedding is not supported by the SGL provider. +// Embedding performs an embedding request to the SGL API. func (provider *SGLProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { return openai.HandleOpenAIEmbeddingRequest( ctx, provider.client, - provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), + key.SGLKeyConfig.URL.GetValue()+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"), request, key, provider.networkConfig.ExtraHeaders, diff --git a/core/providers/sgl/sgl_test.go b/core/providers/sgl/sgl_test.go index 11447f58b4..20236182fc 100644 --- a/core/providers/sgl/sgl_test.go +++ b/core/providers/sgl/sgl_test.go @@ -29,22 +29,22 @@ func TestSGL(t *testing.T) { TextModel: "qwen/qwen2.5-0.5b-instruct", EmbeddingModel: "Alibaba-NLP/gte-Qwen2-1.5B-instruct", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - MultipleImages: true, - CompleteEnd2End: true, - Embedding: true, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + MultipleImages: true, + 
CompleteEnd2End: true, + Embedding: true, + ListModels: true, }, } diff --git a/core/providers/utils/decompression_test.go b/core/providers/utils/decompression_test.go index aab4d1d6e8..16ed30d608 100644 --- a/core/providers/utils/decompression_test.go +++ b/core/providers/utils/decompression_test.go @@ -496,10 +496,10 @@ func TestSafeReset(t *testing.T) { if ok { t.Fatal("expected false for panicking reset") } }) - t.Run("panic_nil", func(t *testing.T) { - ok := safeReset(func() error { panic(nil) }) + t.Run("panic_nonnil", func(t *testing.T) { + ok := safeReset(func() error { panic("") }) if ok { - t.Fatal("expected false for nil panic") + t.Fatal("expected false for non-nil panic") } diff --git a/core/providers/utils/images.go b/core/providers/utils/images.go new file mode 100644 index 0000000000..12bc01c8ba --- /dev/null +++ b/core/providers/utils/images.go @@ -0,0 +1,50 @@ +package utils + +import ( + "strconv" + "strings" +) + +// ConvertSizeToAspectRatioAndResolution converts a standard size string (e.g., "1024x1024") +// to an aspect ratio and image size tier. +// aspectRatio is one of "1:1", "3:4", "4:3", "9:16", "16:9" (empty if unrecognised). +// imageSize is one of "1K", "2K", "4K" (empty if out of range).
+func ConvertSizeToAspectRatioAndResolution(size string) (aspectRatio, imageSize string) { + parts := strings.Split(size, "x") + if len(parts) != 2 { + return "", "" + } + + width, err1 := strconv.Atoi(parts[0]) + height, err2 := strconv.Atoi(parts[1]) + if err1 != nil || err2 != nil { + return "", "" + } + + if width <= 0 || height <= 0 { + return "", "" + } + + if width <= 1024 && height <= 1024 { + imageSize = "1K" + } else if width <= 2048 && height <= 2048 { + imageSize = "2K" + } else if width <= 4096 && height <= 4096 { + imageSize = "4K" + } + + ratio := float64(width) / float64(height) + if ratio >= 0.99 && ratio <= 1.01 { + aspectRatio = "1:1" + } else if ratio >= 0.74 && ratio <= 0.76 { + aspectRatio = "3:4" + } else if ratio >= 1.32 && ratio <= 1.34 { + aspectRatio = "4:3" + } else if ratio >= 0.56 && ratio <= 0.57 { + aspectRatio = "9:16" + } else if ratio >= 1.77 && ratio <= 1.78 { + aspectRatio = "16:9" + } + + return aspectRatio, imageSize +} diff --git a/core/providers/utils/large_response.go b/core/providers/utils/large_response.go index a7e0e7bf36..e62d375c9a 100644 --- a/core/providers/utils/large_response.go +++ b/core/providers/utils/large_response.go @@ -116,7 +116,6 @@ func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Resp func FinalizeResponseWithLargeDetection( ctx *schemas.BifrostContext, resp *fasthttp.Response, - providerName schemas.ModelProvider, logger schemas.Logger, ) ([]byte, bool, *schemas.BifrostError) { responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64) @@ -125,7 +124,7 @@ func FinalizeResponseWithLargeDetection( if responseThreshold <= 0 { body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } // Copy body before caller releases resp return append([]byte(nil), body...), false, 
nil @@ -142,14 +141,14 @@ func FinalizeResponseWithLargeDetection( } bodyBytes, readErr := io.ReadAll(reader) if readErr != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } return bodyBytes, false, nil } // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -169,7 +168,7 @@ func FinalizeResponseWithLargeDetection( bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1)) if readErr != nil { releaseGzip() - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } if int64(len(bodyBytes)) <= responseThreshold { releaseGzip() @@ -195,7 +194,7 @@ func FinalizeResponseWithLargeDetection( // No stream — buffered fallback body, err := CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -206,11 +205,11 @@ func FinalizeResponseWithLargeDetection( if bodyStream == nil { // No stream available — fall back to buffered read if logger != nil { - logger.Warn("large-response fallback to buffered path: provider=%s content_length=%d threshold=%d body_stream_nil=true", providerName, contentLength, responseThreshold) + logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d body_stream_nil=true", contentLength, responseThreshold) } body, err := 
CheckAndDecodeBody(resp) if err != nil { - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } return append([]byte(nil), body...), false, nil } @@ -232,7 +231,7 @@ func FinalizeResponseWithLargeDetection( if wasGzip { ReleaseGzipReader(gz) } - return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr, providerName) + return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr) } prefetchBuf = prefetchBuf[:n] diff --git a/core/providers/utils/make_request_test.go b/core/providers/utils/make_request_test.go index ce1610d7bb..ec2bf771bc 100644 --- a/core/providers/utils/make_request_test.go +++ b/core/providers/utils/make_request_test.go @@ -295,7 +295,7 @@ func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) } func TestNewBifrostTimeoutError(t *testing.T) { - err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded, "openai") + err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded) if !err.IsBifrostError { t.Fatal("expected IsBifrostError to be true") diff --git a/core/providers/utils/models.go b/core/providers/utils/models.go new file mode 100644 index 0000000000..f1b3d0351b --- /dev/null +++ b/core/providers/utils/models.go @@ -0,0 +1,356 @@ +// Package utils — models.go +// Centralised pipeline for filtering and backfilling models in ListModels responses. +// +// Every provider's ToBifrostListModelsResponse follows the same logical steps: +// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID) +// 2. Filter (allowlist + blacklist check on the resolved name) +// 3. Backfill entries that were not returned by the API but should appear in output +// +// Providers plug in custom MatchFns to extend the default matching behaviour.
+// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns. +package utils + +import ( + "sort" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// ToDisplayName converts a raw model ID or alias key into a human-readable display name. +// Splits on "-" or "_", title-cases each word, and joins with spaces. +// +// "gemini-pro" → "Gemini Pro" +// "claude_3_opus" → "Claude 3 Opus" +// "gpt-4-turbo" → "Gpt 4 Turbo" +func ToDisplayName(id string) string { + caser := cases.Title(language.English) + parts := strings.FieldsFunc(id, func(r rune) bool { + return r == '-' || r == '_' + }) + if len(parts) == 0 { + return "" + } + for i, part := range parts { + if part != "" { + parts[i] = caser.String(strings.ToLower(part)) + } + } + return strings.Join(parts, " ") +} + +// MatchFn reports whether two model ID strings should be treated as equivalent. +// Functions are applied in order during every comparison — the first one that +// returns true short-circuits the rest. +// +// Example fns (only exact matching is in DefaultMatchFns): +// +// exactMatch("gpt-4", "gpt-4") → true +// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5-sonnet") → true +type MatchFn func(a, b string) bool + +// DefaultMatchFns returns the standard matching functions used by most providers. +// Currently only performs case-insensitive exact matching. +// +// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet") +// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping. +// It can be appended here if fuzzy base-model matching is ever needed globally. +func DefaultMatchFns() []MatchFn { + return []MatchFn{ + func(a, b string) bool { return strings.EqualFold(a, b) }, + } +} + +// matches reports whether a and b are considered equal by any of the provided fns. +// Returns true on the first fn that returns true.
+func matches(a, b string, fns []MatchFn) bool { + for _, fn := range fns { + if fn(a, b) { + return true + } + } + return false +} + +// FilterResult is the outcome of running ListModelsPipeline.FilterModel for a single model +// from the provider's API response. Each returned result represents one alias +// entry (or the raw model ID when no alias matched) that passed all filters. +type FilterResult struct { + // ResolvedID is the user-facing model name to use as the ID suffix. + // If the model matched an alias VALUE, this is the alias KEY. + // Otherwise this is the original model ID from the API response. + // + // Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"} + // → ResolvedID = "my-gpt4" + // Example: API returns "gpt-3.5-turbo", no alias match + // → ResolvedID = "gpt-3.5-turbo" + ResolvedID string + + // AliasValue is the provider-specific model ID when the model was matched + // via an alias. Set as the model.Alias field so callers know the underlying ID. + // Empty when the model was matched directly (no alias involved). + // + // Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched + // → AliasValue = "gpt-4-turbo" + AliasValue string +} + +// ListModelsPipeline holds all the context needed to filter and backfill models in a +// single ListModels response. Construct one per ToBifrostListModelsResponse call +// and use its methods instead of passing params + matchFns to every function.
+// +// pipeline := &providerUtils.ListModelsPipeline{ +// AllowedModels: key.Models, +// BlacklistedModels: key.BlacklistedModels, +// Aliases: key.Aliases, +// Unfiltered: request.Unfiltered, +// ProviderKey: schemas.OpenAI, +// MatchFns: providerUtils.DefaultMatchFns(), +// } +// if pipeline.ShouldEarlyExit() { return empty } +// result := pipeline.FilterModel(model.ID) +// pipeline.BackfillModels(included) +type ListModelsPipeline struct { + AllowedModels schemas.WhiteList + BlacklistedModels schemas.BlackList + // Aliases maps user-facing alias keys to provider-specific model IDs. + // e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"} + Aliases map[string]string + Unfiltered bool + ProviderKey schemas.ModelProvider + // MatchFns is the ordered list of equivalence functions used for every + // model ID comparison. Use DefaultMatchFns() for standard behaviour; + // providers may append additional fns (e.g. Bedrock's region-prefix remover). + MatchFns []MatchFn +} + +// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately +// return an empty response without processing any models. +// +// Returns true when: +// - not unfiltered AND allowlist is empty AND no aliases configured +// (there is nothing to match against — all models would be filtered out anyway) +// - not unfiltered AND blacklist blocks everything +// +// Note: allowlist empty + aliases present → do NOT early exit. +// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels). +func (p *ListModelsPipeline) ShouldEarlyExit() bool { + if p.Unfiltered { + return false + } + if p.BlacklistedModels.IsBlockAll() { + return true + } + if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 { + return true + } + return false +} + +// aliasMatch holds a single alias key/value pair returned by resolveModelID. 
+type aliasMatch struct { + key string + value string +} + +// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns, +// plus the raw model ID itself as an additional entry so both the alias key and the original model +// name appear in the list-models output. +// Results are sorted by alias key (case-insensitive) for deterministic ordering. +// +// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID. +// +// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"} +// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}] +// +// If no alias matches → returns a single entry with the original model ID and no alias value. +// +// Example: modelID="gpt-3.5-turbo", no alias match +// → [{key:"gpt-3.5-turbo", value:""}] +func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch { + var candidates []aliasMatch + for aliasKey, providerID := range p.Aliases { + if matches(modelID, providerID, p.MatchFns) { + candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID}) + } + } + if len(candidates) == 0 { + return []aliasMatch{{key: modelID, value: ""}} + } + // Also include the raw model ID so both the alias key and the original name appear in output. + candidates = append(candidates, aliasMatch{key: modelID, value: ""}) + sort.Slice(candidates, func(i, j int) bool { + return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key) + }) + return candidates +} + +// FilterModel applies the full filter pipeline for a single model from the API response. +// +// Steps: +// 1. Resolve name — check alias VALUES for a match (uses MatchFns). +// If matched: resolvedName = alias KEY, aliasValue = provider ID. +// If not matched: resolvedName = original modelID, aliasValue = "". +// 2. Allowlist check (only when allowlist is restricted, i.e. 
not wildcard): +// Skip if resolvedName is not in AllowedModels. +// 3. Blacklist check (always): +// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything. +// 4. Return one FilterResult per passing candidate. +// +// An empty slice means the model should be skipped entirely. +// When multiple aliases map to the same provider model ID, each alias that passes +// the filters produces its own FilterResult entry. +// +// Examples: +// +// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// FilterModel("gpt-3.5") → [] (not in allowlist) +// +// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[] +// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""}, +// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"}, +// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}] +// +// allowedModels=["gpt-3.5"], aliases={}, blacklist=[] +// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}] +// FilterModel("gpt-4") → [] +func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult { + // Step 1: resolve name — collect all alias matches (or the raw ID if none match). + candidates := p.resolveModelID(modelID) + + var results []FilterResult + for _, candidate := range candidates { + resolvedName := candidate.key + + // Step 2: allowlist check. + // IsRestricted() is true for both an explicit list AND an empty list (deny-all). + // Only a wildcard allowlist marker bypasses this check (pass-through). + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + allowed := false + for _, entry := range p.AllowedModels { + if matches(resolvedName, entry, p.MatchFns) { + allowed = true + break + } + } + if !allowed { + continue + } + } + + // Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases. 
+ if !p.Unfiltered { + blacklisted := false + for _, entry := range p.BlacklistedModels { + if matches(resolvedName, entry, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + } + + results = append(results, FilterResult{ + ResolvedID: resolvedName, + AliasValue: candidate.value, + }) + } + return results +} + +// BackfillModels adds model entries that were configured by the caller but not +// returned by the provider's API response (or not matched during filtering). +// +// The `included` map tracks model IDs (lowercased) already added during the +// filter pass, used to avoid duplicates. +// +// Two cases depending on whether the allowlist is restricted: +// +// Case A — allowlist restricted (caller specified explicit model names): +// +// Add each allowlist entry that is not yet in `included`, skip if blacklisted. +// If the entry has an alias mapping (aliases[entry] exists), set Alias to the +// provider-specific ID so callers can route to the right model. +// +// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"} +// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"} +// +// Case B — allowlist wildcard (*) only: +// +// We don't know all model names (no explicit list), so we only backfill entries +// that were explicitly configured via aliases and not yet matched from the API. +// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard. +// +// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included +// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"} +// +// Blacklist always wins — nothing blacklisted is added in either case. +func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model { + var result []schemas.Model + + if !p.Unfiltered && p.AllowedModels.IsRestricted() { + // Case A: backfill explicit allowlist entries not yet matched. 
+ for _, entry := range p.AllowedModels { + if included[strings.ToLower(entry)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(entry, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + m := schemas.Model{ + ID: string(p.ProviderKey) + "/" + entry, + Name: schemas.Ptr(ToDisplayName(entry)), + } + // If this allowlist entry has an alias, surface the provider-specific ID. + for aliasKey, providerID := range p.Aliases { + if matches(entry, aliasKey, p.MatchFns) { + m.Alias = schemas.Ptr(providerID) + break + } + } + result = append(result, m) + } + return result + } + + // Case B: wildcard allowlist — backfill only explicitly configured aliases. + if !p.Unfiltered && len(p.Aliases) > 0 { + for aliasKey, providerID := range p.Aliases { + if included[strings.ToLower(aliasKey)] { + continue + } + // Blacklist check. + blacklisted := false + for _, bl := range p.BlacklistedModels { + if matches(aliasKey, bl, p.MatchFns) { + blacklisted = true + break + } + } + if blacklisted { + continue + } + result = append(result, schemas.Model{ + ID: string(p.ProviderKey) + "/" + aliasKey, + Name: schemas.Ptr(ToDisplayName(aliasKey)), + Alias: schemas.Ptr(providerID), + }) + } + } + + return result +} diff --git a/core/providers/utils/stream.go b/core/providers/utils/stream.go index 6b3ea417b4..1cdd011602 100644 --- a/core/providers/utils/stream.go +++ b/core/providers/utils/stream.go @@ -1,6 +1,8 @@ package utils import ( + "context" + schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -20,7 +22,12 @@ import ( // // If the source channel is closed immediately (empty stream), it returns a // nil channel with nil error. drainDone is already closed. +// +// The ctx argument cancels the background forwarding goroutine if the consumer +// abandons the returned wrapped channel. 
On ctx.Done the goroutine drains the +// source stream so the upstream provider's blocked send can exit cleanly. func CheckFirstStreamChunkForError( + ctx context.Context, stream chan *schemas.BifrostStreamChunk, ) (chan *schemas.BifrostStreamChunk, <-chan struct{}, *schemas.BifrostError) { firstChunk, ok := <-stream @@ -53,7 +60,15 @@ func CheckFirstStreamChunkForError( defer close(done) defer close(wrapped) for chunk := range stream { - wrapped <- chunk + select { + case wrapped <- chunk: + case <-ctx.Done(): + // Consumer abandoned the wrapped channel. Drain the source so the + // provider's blocked send unblocks and its goroutine can exit. + for range stream { + } + return + } } }() return wrapped, done, nil diff --git a/core/providers/utils/stream_test.go b/core/providers/utils/stream_test.go index 45c88853fa..7e843fa4fd 100644 --- a/core/providers/utils/stream_test.go +++ b/core/providers/utils/stream_test.go @@ -1,7 +1,9 @@ package utils import ( + "context" "testing" + "time" schemas "github.com/maximhq/bifrost/core/schemas" ) @@ -18,7 +20,7 @@ func TestCheckFirstStreamChunk_ErrorInFirstChunk(t *testing.T) { } close(stream) - _, drainDone, err := CheckFirstStreamChunkForError(stream) + _, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream) if err == nil { t.Fatal("expected error, got nil") } @@ -47,7 +49,7 @@ func TestCheckFirstStreamChunk_ValidFirstChunk(t *testing.T) { stream <- chunk2 close(stream) - wrapped, _, err := CheckFirstStreamChunkForError(stream) + wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -75,7 +77,7 @@ func TestCheckFirstStreamChunk_EmptyStream(t *testing.T) { stream := make(chan *schemas.BifrostStreamChunk) close(stream) - wrapped, drainDone, err := CheckFirstStreamChunkForError(stream) + wrapped, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream) if err != nil { t.Fatalf("unexpected error: %v", 
err) } @@ -110,7 +112,7 @@ func TestCheckFirstStreamChunk_ErrorInSecondChunk(t *testing.T) { close(stream) // Should NOT return error — only first chunk matters for retry - wrapped, _, err := CheckFirstStreamChunkForError(stream) + wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -149,7 +151,7 @@ func TestCheckFirstStreamChunk_ErrorDrainsSource(t *testing.T) { } close(stream) - _, drainDone, err := CheckFirstStreamChunkForError(stream) + _, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream) if err == nil { t.Fatal("expected error, got nil") } @@ -176,7 +178,7 @@ func TestCheckFirstStreamChunk_ErrorWithEmptyMessage(t *testing.T) { } close(stream) - wrapped, _, err := CheckFirstStreamChunkForError(stream) + wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream) if err != nil { t.Fatalf("unexpected error for empty message: %v", err) } @@ -184,6 +186,49 @@ func TestCheckFirstStreamChunk_ErrorWithEmptyMessage(t *testing.T) { <-wrapped } +func TestCheckFirstStreamChunk_CtxCancelUnblocksWrapper(t *testing.T) { + // Source with cap=1 so wrapped also has cap=1. wrapped is left full by + // the re-injected first chunk, which makes the forwarder goroutine block + // on its next send — the exact leak condition this test guards against. + src := make(chan *schemas.BifrostStreamChunk, 1) + src <- &schemas.BifrostStreamChunk{ + BifrostChatResponse: &schemas.BifrostChatResponse{ID: "1"}, + } + + ctx, cancel := context.WithCancel(context.Background()) + + wrapped, drainDone, err := CheckFirstStreamChunkForError(ctx, src) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wrapped == nil { + t.Fatal("expected wrapped channel, got nil") + } + + // Push a second chunk; forwarder will read it from src and then block + // trying to send into the full wrapped channel (we intentionally never + // read from wrapped). 
+	src <- &schemas.BifrostStreamChunk{
+		BifrostChatResponse: &schemas.BifrostChatResponse{ID: "2"},
+	}
+
+	// Cancel — forwarder must stop trying to send to wrapped and drain src.
+	cancel()
+
+	// Simulate the upstream producer still emitting, then closing. The
+	// drain loop should consume these and terminate.
+	src <- &schemas.BifrostStreamChunk{
+		BifrostChatResponse: &schemas.BifrostChatResponse{ID: "3"},
+	}
+	close(src)
+
+	select {
+	case <-drainDone:
+	case <-time.After(time.Second):
+		t.Fatal("drainDone did not close after ctx cancel; forwarder goroutine leaked")
+	}
+}
+
 func TestCheckFirstStreamChunk_CodeOnlyError(t *testing.T) {
 	// Error with code but no message should be treated as an error
 	stream := make(chan *schemas.BifrostStreamChunk, 2)
@@ -196,7 +241,7 @@ func TestCheckFirstStreamChunk_CodeOnlyError(t *testing.T) {
 	}
 	close(stream)
 
-	_, drainDone, err := CheckFirstStreamChunkForError(stream)
+	_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
 	if err == nil {
 		t.Fatal("expected error for code-only error, got nil")
 	}
diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go
index c1b7259375..748c83210b 100644
--- a/core/providers/utils/utils.go
+++ b/core/providers/utils/utils.go
@@ -178,12 +178,12 @@ func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f
 	}
 	// Check for timeout errors first before checking net.OpError to avoid misclassification
 	if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
-		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop
+		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop
 	}
 	// Check if error implements net.Error and has Timeout() == true
 	var netErr net.Error
 	if errors.As(err, &netErr) && netErr.Timeout() {
-		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, ""), noop
+		return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop
 	}
 	// Check for DNS lookup and network errors after timeout checks
 	var opErr *net.OpError
@@ -1043,7 +1043,7 @@ func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{
 }
 
 // CheckContextAndGetRequestBody checks if the raw request body should be used, and returns it if it exists.
-func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter, providerType schemas.ModelProvider) ([]byte, *schemas.BifrostError) {
+func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) {
 	if IsLargePayloadPassthroughEnabled(ctx) {
 		return nil, nil
 	}
@@ -1052,15 +1052,15 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette
 	if !ok {
 		convertedBody, err := requestConverter()
 		if err != nil {
-			return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err, providerType)
+			return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err)
 		}
 		if convertedBody == nil {
-			return nil, NewBifrostOperationError("request body is not provided", nil, providerType)
+			return nil, NewBifrostOperationError("request body is not provided", nil)
 		}
 		jsonBody, err := MarshalSortedIndent(convertedBody, "", " ")
 		if err != nil {
-			return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType)
+			return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Merge ExtraParams into the JSON if passthrough is enabled
 		if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
@@ -1070,7 +1070,7 @@ func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGette
 			// tool schemas and other order-sensitive JSON structures.
 			jsonBody, err = MergeExtraParamsIntoJSON(jsonBody, extraParams)
 			if err != nil {
-				return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerType)
+				return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 	}
@@ -1367,10 +1367,6 @@ func NewUnsupportedOperationError(requestType schemas.RequestType, providerName
 			Message: fmt.Sprintf("%s is not supported by %s provider", requestType, providerName),
 			Code:    schemas.Ptr("unsupported_operation"),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:    providerName,
-			RequestType: requestType,
-		},
 	}
 }
 
@@ -1593,37 +1589,31 @@ func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult
 
 // NewConfigurationError creates a standardized error for configuration errors.
 // This helper reduces code duplication across providers that have configuration errors.
-func NewConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewConfigurationError(message string) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: false,
 		Error: &schemas.ErrorField{
 			Message: message,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewBifrostOperationError creates a standardized error for bifrost operation errors.
 // This helper reduces code duplication across providers that have bifrost operation errors.
-func NewBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewBifrostOperationError(message string, err error) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: true,
 		Error: &schemas.ErrorField{
 			Message: message,
 			Error:   err,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewBifrostTimeoutError creates a standardized error for provider request timeout errors.
 // Sets StatusCode to 504 (Gateway Timeout) and Error.Type to RequestTimedOut,
 // consistent with HandleStreamTimeout for streaming requests.
-func NewBifrostTimeoutError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError {
+func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError {
 	statusCode := 504
 	errorType := schemas.RequestTimedOut
 	return &schemas.BifrostError{
@@ -1634,15 +1624,12 @@ func NewBifrostTimeoutError(message string, err error, providerType schemas.Mode
 			Type:  &errorType,
 			Error: err,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
 // NewProviderAPIError creates a standardized error for provider API errors.
 // This helper reduces code duplication across providers that have provider API errors.
-func NewProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError {
+func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError {
 	return &schemas.BifrostError{
 		IsBifrostError: false,
 		StatusCode:     &statusCode,
@@ -1653,61 +1640,43 @@ func NewProviderAPIError(message string, err error, statusCode int, providerType
 			Error: err,
 			Type:  errorType,
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider: providerType,
-		},
 	}
 }
 
-// RequestMetadata contains metadata about a request for error reporting.
-// This struct is used to pass request context to parseError functions.
-type RequestMetadata struct {
-	Provider    schemas.ModelProvider
-	Model       string
-	RequestType schemas.RequestType
-}
-
-// ShouldSendBackRawRequest checks if the raw request should be captured.
-// Context overrides are intentionally restricted to asymmetric behavior: a context value can only
-// promote false→true and will not override a true config to false, avoiding accidental suppression.
-// Both full send-back mode and logging-only mode (store_raw_request_response) set
-// BifrostContextKeySendBackRawRequest=true in the request context so a single flag is checked here.
-// In logging-only mode the payload is stripped before the response reaches the client.
+// ShouldSendBackRawRequest checks if raw request bytes should be captured.
+// bifrost.go always writes BifrostContextKeyCaptureRawRequest before provider dispatch,
+// combining provider config, per-request overrides, and store_raw_request_response.
+// The default parameter is a fallback for callers outside the normal bifrost dispatch path.
 func ShouldSendBackRawRequest(ctx context.Context, defaultSendBackRawRequest bool) bool {
-	if sendBackRawRequest, ok := ctx.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok && sendBackRawRequest {
-		return sendBackRawRequest
+	if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawRequest).(bool); ok {
+		return capture
 	}
 	return defaultSendBackRawRequest
 }
 
-// ShouldSendBackRawResponse checks if the raw response should be captured.
-// Context overrides are intentionally restricted to asymmetric behavior: a context value can only
-// promote false→true and will not override a true config to false, avoiding accidental suppression.
-// Both full send-back mode and logging-only mode (store_raw_request_response) set
-// BifrostContextKeySendBackRawResponse=true in the request context so a single flag is checked here.
-// In logging-only mode the payload is stripped before the response reaches the client.
+// ShouldSendBackRawResponse checks if raw response bytes should be captured.
+// bifrost.go always writes BifrostContextKeyCaptureRawResponse before provider dispatch,
+// combining provider config, per-request overrides, and store_raw_request_response.
+// The default parameter is a fallback for callers outside the normal bifrost dispatch path.
 func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse bool) bool {
-	if sendBackRawResponse, ok := ctx.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok && sendBackRawResponse {
-		return sendBackRawResponse
+	if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawResponse).(bool); ok {
+		return capture
 	}
 	return defaultSendBackRawResponse
 }
 
 // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event.
-func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
+func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
 	firstChunk := &schemas.BifrostResponsesStreamResponse{
 		Type:           schemas.ResponsesStreamResponseTypeCreated,
 		SequenceNumber: 0,
 		Response:       &schemas.BifrostResponsesResponse{},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    schemas.ResponsesStreamRequest,
-			Provider:       provider,
-			ModelRequested: model,
-			ChunkIndex:     0,
-			Latency:        time.Since(startTime).Milliseconds(),
+			ChunkIndex: 0,
+			Latency:    time.Since(startTime).Milliseconds(),
 		},
 	}
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{
 		ResponsesStreamResponse: firstChunk,
 	}
@@ -1715,20 +1684,17 @@ func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner
 }
 
 // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event
-func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, provider schemas.ModelProvider, model string, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
+func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk) {
 	chunk := &schemas.BifrostResponsesStreamResponse{
 		Type:           schemas.ResponsesStreamResponseTypeInProgress,
 		SequenceNumber: 1,
 		Response:       &schemas.BifrostResponsesResponse{},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    schemas.ResponsesStreamRequest,
-			Provider:       provider,
-			ModelRequested: model,
-			ChunkIndex:     1,
-			Latency:        time.Since(startTime).Milliseconds(),
+			ChunkIndex: 1,
+			Latency:    time.Since(startTime).Milliseconds(),
 		},
 	}
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{
 		ResponsesStreamResponse: chunk,
 	}
@@ -1736,13 +1702,14 @@ func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunn
 }
 
 // BuildClientStreamChunk constructs a BifrostStreamChunk from post-hook results.
-// It never mutates the shared processedResponse or processedError objects — when in
-// logging-only mode (BifrostContextKeyRawRequestResponseForLogging) it shallow-copies
-// each inner response struct and the BifrostError, nils only the raw fields on those
-// copies, and returns them as the outgoing chunk. This is safe for concurrent PostLLMHook
-// goroutines that still hold references to the originals.
+// It never mutates the shared processedResponse or processedError objects — when raw fields
+// need to be stripped (captured for storage but not for send-back), it shallow-copies each
+// inner response struct and nils only the appropriate per-side field on those copies.
+// This is safe for concurrent PostLLMHook goroutines that still hold references to the originals.
 func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.BifrostResponse, processedError *schemas.BifrostError) *schemas.BifrostStreamChunk {
-	dropRaw, _ := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool)
+	dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool)
+	dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool)
+	drop := dropReq || dropResp
 	streamResponse := &schemas.BifrostStreamChunk{}
 	if processedResponse != nil {
 		streamResponse.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse
@@ -1753,51 +1720,79 @@ func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.Bifr
 		streamResponse.BifrostImageGenerationStreamResponse = processedResponse.ImageGenerationStreamResponse
 		// Strip raw fields from client-facing copies without mutating the shared objects
 		// that PostLLMHook goroutines may still be reading.
-		if dropRaw {
+		if drop {
 			if streamResponse.BifrostTextCompletionResponse != nil {
 				cp := *streamResponse.BifrostTextCompletionResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostTextCompletionResponse = &cp
 			}
 			if streamResponse.BifrostChatResponse != nil {
 				cp := *streamResponse.BifrostChatResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostChatResponse = &cp
 			}
 			if streamResponse.BifrostResponsesStreamResponse != nil {
 				cp := *streamResponse.BifrostResponsesStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostResponsesStreamResponse = &cp
 			}
 			if streamResponse.BifrostSpeechStreamResponse != nil {
 				cp := *streamResponse.BifrostSpeechStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostSpeechStreamResponse = &cp
 			}
 			if streamResponse.BifrostTranscriptionStreamResponse != nil {
 				cp := *streamResponse.BifrostTranscriptionStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostTranscriptionStreamResponse = &cp
 			}
 			if streamResponse.BifrostImageGenerationStreamResponse != nil {
 				cp := *streamResponse.BifrostImageGenerationStreamResponse
-				cp.ExtraFields.RawRequest = nil
-				cp.ExtraFields.RawResponse = nil
+				if dropReq {
+					cp.ExtraFields.RawRequest = nil
+				}
+				if dropResp {
+					cp.ExtraFields.RawResponse = nil
+				}
 				streamResponse.BifrostImageGenerationStreamResponse = &cp
 			}
 		}
 	}
 	if processedError != nil {
-		if dropRaw {
+		if drop {
 			// Strip raw fields from a client-facing copy without mutating the shared error object.
 			errCopy := *processedError
-			errCopy.ExtraFields.RawRequest = nil
-			errCopy.ExtraFields.RawResponse = nil
+			if dropReq {
+				errCopy.ExtraFields.RawRequest = nil
+			}
+			if dropResp {
+				errCopy.ExtraFields.RawResponse = nil
+			}
 			streamResponse.BifrostError = &errCopy
 		} else {
 			streamResponse.BifrostError = processedError
@@ -1893,6 +1888,41 @@ func ProcessAndSendBifrostError(
 	}
 }
 
+// EnsureStreamFinalizerCalled invokes the post-hook span finalizer registered
+// on ctx, if any. Designed to be deferred as the last line of defence in a
+// provider's streaming goroutine (next to SetupStreamCancellation's cleanup):
+//
+//	defer providerUtils.EnsureStreamFinalizerCalled(ctx)
+//
+// On a normal stream end the finalizer is already invoked when the final chunk
+// is processed (via completeDeferredSpan). The registration wraps the closure
+// in sync.Once, so this safety-net call is a noop in that case. It only does
+// real work when the streaming goroutine exits without reaching the final-chunk
+// path — e.g. a panic mid-stream — which would otherwise leak the plugin
+// pipeline back-reference held by the finalizer closure.
+//
+// Panics inside the finalizer are recovered and logged so they never mask an
+// in-flight panic that triggered the defer.
+func EnsureStreamFinalizerCalled(ctx context.Context) {
+	// Install the recover first so any panic — including one triggered by
+	// accessing ctx itself — is caught. This matters because this helper is
+	// called from `defer`, so a panic here would mask the in-flight panic
+	// that invoked the defer.
+	defer func() {
+		if r := recover(); r != nil {
+			getLogger().Debug("recovered panic in deferred stream finalizer: %v", r)
+		}
+	}()
+	if ctx == nil {
+		return
+	}
+	finalizer, ok := ctx.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context))
+	if !ok || finalizer == nil {
+		return
+	}
+	finalizer(ctx)
+}
+
 // SetupStreamCancellation spawns a goroutine that closes the body stream when
 // the context is cancelled or deadline exceeded, unblocking any blocked Read/Scan operations.
 // Returns a cleanup function that MUST be called when streaming is done to
@@ -2015,9 +2045,6 @@ func HandleStreamCancellation(
 	ctx *schemas.BifrostContext,
 	postHookRunner schemas.PostHookRunner,
 	responseChan chan *schemas.BifrostStreamChunk,
-	provider schemas.ModelProvider,
-	model string,
-	requestType schemas.RequestType,
 	logger schemas.Logger,
 ) {
 	// Check if already handled (StreamEndIndicator already set)
@@ -2033,11 +2060,6 @@ func HandleStreamCancellation(
 			Message: "Request cancelled: client disconnected",
 			Type:    schemas.Ptr(schemas.RequestCancelled),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:       provider,
-			ModelRequested: model,
-			RequestType:    requestType,
-		},
 	}
 
 	// Send through PostHook chain - this updates the log to "error" status
@@ -2056,9 +2078,6 @@ func HandleStreamTimeout(
 	ctx *schemas.BifrostContext,
 	postHookRunner schemas.PostHookRunner,
 	responseChan chan *schemas.BifrostStreamChunk,
-	provider schemas.ModelProvider,
-	model string,
-	requestType schemas.RequestType,
 	logger schemas.Logger,
 ) {
 	// Check if already handled (StreamEndIndicator already set)
@@ -2074,11 +2093,6 @@ func HandleStreamTimeout(
 			Message: "Request timed out: deadline exceeded",
 			Type:    schemas.Ptr(schemas.RequestTimedOut),
 		},
-		ExtraFields: schemas.BifrostErrorExtraFields{
-			Provider:       provider,
-			ModelRequested: model,
-			RequestType:    requestType,
-		},
 	}
 
 	// Send through PostHook chain - this updates the log to "error" status
@@ -2094,25 +2108,16 @@ func ProcessAndSendError(
 	postHookRunner schemas.PostHookRunner,
 	err error,
 	responseChan chan *schemas.BifrostStreamChunk,
-	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
-	model string,
 	logger schemas.Logger,
 ) {
 	// Send scanner error through channel
-	bifrostError :=
-		&schemas.BifrostError{
-			IsBifrostError: true,
-			Error: &schemas.ErrorField{
-				Message: fmt.Sprintf("Error reading stream: %v", err),
-				Error:   err,
-			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				RequestType:    requestType,
-				Provider:       providerName,
-				ModelRequested: model,
-			},
-		}
+	bifrostError := &schemas.BifrostError{
+		IsBifrostError: true,
+		Error: &schemas.ErrorField{
+			Message: fmt.Sprintf("Error reading stream: %v", err),
+			Error:   err,
+		},
+	}
 
 	processedResponse, processedError := postHookRunner(ctx, nil, bifrostError)
 	if HandleStreamControlSkip(processedError) {
@@ -2144,8 +2149,6 @@ func CreateBifrostTextCompletionChunkResponse(
 	finishReason *string,
 	currentChunkIndex int,
 	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
-	model string,
 ) *schemas.BifrostTextCompletionResponse {
 	response := &schemas.BifrostTextCompletionResponse{
 		ID: id,
@@ -2158,10 +2161,7 @@ func CreateBifrostTextCompletionChunkResponse(
 			},
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    requestType,
-			Provider:       providerName,
-			ModelRequested: model,
-			ChunkIndex:     currentChunkIndex + 1,
+			ChunkIndex: currentChunkIndex + 1,
 		},
 	}
 	return response
@@ -2173,14 +2173,15 @@ func CreateBifrostChatCompletionChunkResponse(
 	usage *schemas.BifrostLLMUsage,
 	finishReason *string,
 	currentChunkIndex int,
-	requestType schemas.RequestType,
-	providerName schemas.ModelProvider,
 	model string,
+	created int,
 ) *schemas.BifrostChatResponse {
 	response := &schemas.BifrostChatResponse{
-		ID:     id,
-		Object: "chat.completion.chunk",
-		Usage:  usage,
+		ID:      id,
+		Model:   model,
+		Created: created,
+		Object:  "chat.completion.chunk",
+		Usage:   usage,
 		Choices: []schemas.BifrostResponseChoice{
 			{
 				FinishReason: finishReason,
@@ -2190,10 +2191,7 @@ func CreateBifrostChatCompletionChunkResponse(
 			},
 		},
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType:    requestType,
-			Provider:       providerName,
-			ModelRequested: model,
-			ChunkIndex:     currentChunkIndex + 1,
+			ChunkIndex: currentChunkIndex + 1,
 		},
 	}
 	return response
@@ -2281,7 +2279,7 @@ func GetBifrostResponseForStreamResponse(
 	transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse,
 	imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse,
 ) *schemas.BifrostResponse {
-	//TODO add bifrost response pooling here
+	// TODO add bifrost response pooling here
 	bifrostResponse := &schemas.BifrostResponse{}
 
 	switch {
@@ -2363,10 +2361,7 @@ func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse
 // extractSuccessfulListModelsResponses extracts successful responses from a results channel
 // and tracks per-key status information. This utility reduces code duplication across providers
 // for handling multi-key ListModels requests.
-func extractSuccessfulListModelsResponses(
-	results chan schemas.ListModelsByKeyResult,
-	providerName schemas.ModelProvider,
-) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) {
+func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) {
 	var successfulResponses []*schemas.BifrostListModelsResponse
 	var keyStatuses []schemas.KeyStatus
 	var lastError *schemas.BifrostError
@@ -2384,7 +2379,7 @@ func extractSuccessfulListModelsResponses(
 			getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, errMsg))
 			keyStatuses = append(keyStatuses, schemas.KeyStatus{
 				KeyID:    result.KeyID,
-				Provider: providerName,
+				Provider: provider,
 				Status:   schemas.KeyStatusListModelsFailed,
 				Error:    result.Err,
 			})
@@ -2394,7 +2389,7 @@ func extractSuccessfulListModelsResponses(
 
 		keyStatuses = append(keyStatuses, schemas.KeyStatus{
 			KeyID:    result.KeyID,
-			Provider: providerName,
+			Provider: provider,
 			Status:   schemas.KeyStatusSuccess,
 		})
 		successfulResponses = append(successfulResponses, result.Response)
@@ -2409,10 +2404,6 @@ func extractSuccessfulListModelsResponses(
 			Error: &schemas.ErrorField{
 				Message: "all keys failed to list models",
 			},
-			ExtraFields: schemas.BifrostErrorExtraFields{
-				Provider:    providerName,
-				RequestType: schemas.ListModelsRequest,
-			},
 		}
 	}
 
@@ -2470,6 +2461,21 @@ func HandleMultipleListModelsRequests(
 		wg.Add(1)
 		go func(k schemas.Key) {
 			defer wg.Done()
+			// Should never panic, but if it does, we need to handle it gracefully
+			defer func() {
+				if r := recover(); r != nil {
+					getLogger().Error("panic in listModelsByKey for key %s (%s): %v", k.Name, k.ID, r)
+					results <- schemas.ListModelsByKeyResult{
+						Err: &schemas.BifrostError{
+							IsBifrostError: true,
+							Error: &schemas.ErrorField{
+								Message: "internal error while listing models for key",
+							},
+						},
+						KeyID: k.ID,
+					}
+				}
+			}()
 			resp, bifrostErr := listModelsByKey(ctx, k, request)
 			results <- schemas.ListModelsByKeyResult{Response: resp, Err: bifrostErr, KeyID: k.ID}
 		}(key)
@@ -2495,8 +2501,6 @@ func HandleMultipleListModelsRequests(
 	// Set ExtraFields
 	latency := time.Since(startTime)
-	response.ExtraFields.Provider = request.Provider
-	response.ExtraFields.RequestType = schemas.ListModelsRequest
 	response.ExtraFields.Latency = latency.Milliseconds()
 
 	return response, nil
@@ -2557,9 +2561,8 @@ func GetReasoningEffortFromBudgetTokens(
 	}
 }
 
-// GetBudgetTokensFromReasoningEffort converts OpenAI reasoning effort
-// into a reasoning token budget.
-// effort ∈ {"none", "minimal", "low", "medium", "high"}
+// GetBudgetTokensFromReasoningEffort converts reasoning effort into a reasoning token budget.
+// effort ∈ {"none", "minimal", "low", "medium", "high", "xhigh", "max"}
 func GetBudgetTokensFromReasoningEffort(
 	effort string,
 	minBudgetTokens int,
@@ -2589,6 +2592,10 @@ func GetBudgetTokensFromReasoningEffort(
 		ratio = 0.425
 	case "high":
 		ratio = 0.80
+	case "xhigh":
+		ratio = 0.92
+	case "max":
+		ratio = 1.0
 	default:
 		// Unknown effort → safe default
 		ratio = 0.425
@@ -2653,10 +2660,10 @@ func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostRe
 	if accumulatedResp != nil {
 		// Use accumulated response for attributes (includes full content, tool calls, etc.)
-		tracer.PopulateLLMResponseAttributes(handle, accumulatedResp, err)
+		tracer.PopulateLLMResponseAttributes(ctx, handle, accumulatedResp, err)
 	} else if result != nil {
 		// Fall back to final chunk if no accumulated data (shouldn't happen normally)
-		tracer.PopulateLLMResponseAttributes(handle, result, err)
+		tracer.PopulateLLMResponseAttributes(ctx, handle, result, err)
 	}
 
 	// Finalize aggregated post-hook spans before ending the LLM span
diff --git a/core/providers/utils/utils_test.go b/core/providers/utils/utils_test.go
index 5ca567fce0..e832980f4f 100644
--- a/core/providers/utils/utils_test.go
+++ b/core/providers/utils/utils_test.go
@@ -1107,7 +1107,8 @@ func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
 	t.Run("logging-only: raw fields stripped from image gen chunk, original preserved", func(t *testing.T) {
 		ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
-		ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+		ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+		ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 
 		chunk := BuildClientStreamChunk(ctx, response, nil)
 		if chunk.BifrostImageGenerationStreamResponse == nil {
@@ -1148,9 +1149,9 @@ func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
 }
 
 // TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk verifies
-// that when BifrostContextKeyRawRequestResponseForLogging is set, ProcessAndSendResponse
-// strips RawRequest and RawResponse from the outgoing stream chunk, while leaving other
-// ExtraFields intact. It also verifies that the original BifrostResponse is not mutated
+// that when drop-raw context flags are set, ProcessAndSendResponse strips RawRequest and
+// RawResponse from the outgoing stream chunk, while leaving other ExtraFields intact.
+// It also verifies that the original BifrostResponse is not mutated
 // (shared object safety for PostLLMHook goroutines).
 func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk(t *testing.T) {
 	rawReq := json.RawMessage(`{"model":"gpt-4","messages":[]}`)
@@ -1177,7 +1178,8 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChu
 		t.Run(tt.name, func(t *testing.T) {
 			ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
 			if tt.loggingOnly {
-				ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 			}
 
 			response := &schemas.BifrostResponse{
@@ -1237,9 +1239,9 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChu
 }
 
 // TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk verifies
-// that when BifrostContextKeyRawRequestResponseForLogging is set, raw data is stripped
-// from BifrostError payloads embedded in stream chunks, without mutating the shared
-// BifrostError object (shared object safety for PostLLMHook goroutines).
+// that when drop-raw context flags are set, raw data is stripped from BifrostError
+// payloads embedded in stream chunks, without mutating the shared BifrostError object
+// (shared object safety for PostLLMHook goroutines).
 func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(t *testing.T) {
 	rawReq := json.RawMessage(`{"model":"gpt-4"}`)
 	rawResp := json.RawMessage(`{"error":"rate limit exceeded"}`)
@@ -1265,7 +1267,8 @@ func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(
 		t.Run(tt.name, func(t *testing.T) {
 			ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
 			if tt.loggingOnly {
-				ctx.SetValue(schemas.BifrostContextKeyRawRequestResponseForLogging, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
+				ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
 			}
 
 			// Use a postHookRunner that converts the response to a BifrostError with raw data
diff --git a/core/providers/vertex/embedding.go b/core/providers/vertex/embedding.go
index 0fc0ad598f..54662f50fe 100644
--- a/core/providers/vertex/embedding.go
+++ b/core/providers/vertex/embedding.go
@@ -110,8 +110,6 @@ func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.B
 		Data:  embeddings,
 		Usage: usage,
 		ExtraFields: schemas.BifrostResponseExtraFields{
-			RequestType: schemas.EmbeddingRequest,
-			Provider:    schemas.Vertex,
 		},
 	}
 }
diff --git a/core/providers/vertex/errors.go b/core/providers/vertex/errors.go
index 6b255835d4..e0ed7f1d3d 100644
--- a/core/providers/vertex/errors.go
+++ b/core/providers/vertex/errors.go
@@ -10,25 +10,13 @@ import (
 	"github.com/valyala/fasthttp"
 )
 
-func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetadata) *schemas.BifrostError {
-	var providerName schemas.ModelProvider
-	if meta != nil {
-		providerName = meta.Provider
-	}
-
+func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError {
 	var openAIErr schemas.BifrostError
 	var vertexErr []VertexError
 
 	decodedBody, err := providerUtils.CheckAndDecodeBody(resp)
 	if err != nil {
-		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
-		}
+		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
 		return bifrostErr
 	}
@@ -42,13 +30,6 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 			Message: schemas.ErrProviderResponseEmpty,
 		},
 	}
-	if meta != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider:       meta.Provider,
-			ModelRequested: meta.Model,
-			RequestType:    meta.RequestType,
-		}
-	}
 	return bifrostErr
 }
@@ -61,26 +42,20 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 			Message: schemas.ErrProviderResponseHTML,
 			Error:   errors.New(string(decodedBody)),
 		},
-	}
-	if meta != nil {
-		bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-			Provider:       meta.Provider,
-			ModelRequested: meta.Model,
-			RequestType:    meta.RequestType,
-		}
+		ExtraFields: schemas.BifrostErrorExtraFields{
+			RawResponse: string(decodedBody),
+		},
 	}
 	return bifrostErr
 }
 
 	createError := func(message string) *schemas.BifrostError {
-		bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), providerName, nil, nil)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
+		bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil)
+		var rawResponse interface{}
+		if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil {
+			rawResponse = string(decodedBody)
 		}
+		bifrostErr.ExtraFields.RawResponse = rawResponse
 		return bifrostErr
 	}
@@ -93,14 +68,7 @@ func parseVertexError(resp *fasthttp.Response, meta *providerUtils.RequestMetada
 	// Try VertexValidationError format (validation errors from Mistral endpoint)
 	var validationErr VertexValidationError
 	if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil {
-		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName)
-		if meta != nil {
-			bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
-				Provider:       meta.Provider,
-				ModelRequested: meta.Model,
-				RequestType:    meta.RequestType,
-			}
-		}
+		bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
 		return bifrostErr
 	}
 	if len(validationErr.Detail) > 0 {
diff --git a/core/providers/vertex/models.go b/core/providers/vertex/models.go
index 28b5598022..2fbe83979d 100644
--- a/core/providers/vertex/models.go
+++ b/core/providers/vertex/models.go
@@ -1,12 +1,10 @@
 package vertex
 
 import (
-	"slices"
 	"strings"
 
+	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
 	"github.com/maximhq/bifrost/core/schemas"
-	"golang.org/x/text/cases"
-	"golang.org/x/text/language"
 )
 
 // VertexRankRequest represents the Discovery Engine rank API request.
@@ -56,49 +54,6 @@ type vertexRerankOptions struct {
 	UserLabels map[string]string
 }
 
-// formatDeploymentName converts a deployment alias into a human-readable name.
-// It splits the alias by "-" or "_", capitalizes each word, and joins them with spaces.
-// Example: "gemini-pro" → "Gemini Pro", "claude_3_opus" → "Claude 3 Opus"
-func formatDeploymentName(alias string) string {
-	caser := cases.Title(language.English)
-
-	// Try splitting by hyphen first, then underscore
-	var parts []string
-	if strings.Contains(alias, "-") {
-		parts = strings.Split(alias, "-")
-	} else if strings.Contains(alias, "_") {
-		parts = strings.Split(alias, "_")
-	} else {
-		// No delimiter found, just capitalize the whole string
-		return caser.String(strings.ToLower(alias))
-	}
-
-	// Capitalize each part
-	for i, part := range parts {
-		if part != "" {
-			parts[i] = caser.String(strings.ToLower(part))
-		}
-	}
-
-	return strings.Join(parts, " ")
-}
-
-// findDeploymentMatch finds a matching deployment value in the deployments map.
-// Returns the deployment value and alias if found, empty strings otherwise.
-func findDeploymentMatch(deployments map[string]string, customModelID string) (deploymentValue, alias string) {
-	// Check exact match by deployment value
-	for aliasKey, depValue := range deployments {
-		if depValue == customModelID {
-			return depValue, aliasKey
-		}
-	}
-	// Check exact match by alias/key
-	if deployment, ok := deployments[customModelID]; ok {
-		return deployment, customModelID
-	}
-	return "", ""
-}
-
 // ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format.
 // It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels).
 //
@@ -114,7 +69,7 @@ func findDeploymentMatch(deployments map[string]string, customModelID string) (d
 // - If allowedModels is empty, all models are allowed
 // - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included
 // - Deployments map is used to match model IDs to aliases and filter accordingly
-func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels []string, deployments map[string]string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -123,10 +78,22 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 		Data: make([]schemas.Model, 0, len(response.Models)),
 	}
 
-	// Track which model IDs have been added to avoid duplicates
-	addedModelIDs := make(map[string]bool)
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       schemas.Vertex,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
+	}
+
+	included := make(map[string]bool)
 
-	// First pass: Process all models from the Vertex AI API response (custom models)
+	// Process all models from the Vertex AI API response (custom deployed models).
+	// The model ID is extracted from the endpoint URL last segment.
 	for _, model := range response.Models {
 		if len(model.DeployedModels) == 0 {
 			continue
@@ -142,110 +109,28 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 			continue
 		}
 
-		// Filter if model is not present in both lists (when both are non-empty)
-		// Empty lists mean "allow all" for that dimension
-		var deploymentValue, deploymentAlias string
-		shouldFilter := false
-		if !unfiltered && len(allowedModels) > 0 && len(deployments) > 0 {
-			// Both lists are present: model must be in allowedModels AND deployments
-			// AND the deployment alias must also be in allowedModels
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID)
-			inDeployments := deploymentAlias != ""
-
-			// Check if deployment alias is also in allowedModels (direct string match)
-			deploymentAliasInAllowedModels := false
-			if deploymentAlias != "" {
-				deploymentAliasInAllowedModels = slices.Contains(allowedModels, deploymentAlias)
+		for _, result := range pipeline.FilterModel(customModelID) {
+			resolvedKey := strings.ToLower(result.ResolvedID)
+			if included[resolvedKey] {
+				continue
 			}
-
-			// Filter if: model not in deployments OR deployment alias not in allowedModels
-			shouldFilter = !inDeployments || !deploymentAliasInAllowedModels
-		} else if !unfiltered && len(allowedModels) > 0 {
-			// Only allowedModels is present: filter if model is not in allowedModels
-			shouldFilter = !slices.Contains(allowedModels, customModelID)
-		} else if !unfiltered && len(deployments) > 0 {
-			// Only deployments is present: filter if model is not in deployments
-			deploymentValue, deploymentAlias = findDeploymentMatch(deployments, customModelID)
-			shouldFilter = deploymentValue == ""
-		}
-		// If both are empty, shouldFilter remains false (allow all)
-
-		if shouldFilter {
-			continue
-		}
-
-		modelID := customModelID
-
-		if !unfiltered && (slices.Contains(blacklistedModels, customModelID) || slices.Contains(blacklistedModels, deploymentAlias)) {
-			continue
-		}
-
-		modelEntry := schemas.Model{
-			ID:          string(schemas.Vertex) + "/" + modelID,
-			Name:        schemas.Ptr(model.DisplayName),
-			Description: schemas.Ptr(model.Description),
-			Created:     schemas.Ptr(model.VersionCreateTime.Unix()),
-		}
-		// Set deployment info if matched via deployments
-		if deploymentValue != "" && deploymentAlias != "" {
-			modelEntry.ID = string(schemas.Vertex) + "/" + deploymentAlias
-			modelEntry.Deployment = schemas.Ptr(deploymentValue)
-		}
-		bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-		addedModelIDs[modelEntry.ID] = true
-	}
-	}
-
-	// Second pass: Backfill deployments that were not matched from the API response
-	if !unfiltered && len(deployments) > 0 {
-		for alias, deploymentValue := range deployments {
-			// Check if model already exists in the list
-			modelID := string(schemas.Vertex) + "/" + alias
-			if addedModelIDs[modelID] {
-				continue
-			}
-			// If allowedModels is non-empty, only include if alias is in the list
-			if len(allowedModels) > 0 && !slices.Contains(allowedModels, alias) {
-				continue
-			}
-			if slices.Contains(blacklistedModels, alias) {
-				continue
-			}
-
-			modelName := formatDeploymentName(alias)
-			modelEntry := schemas.Model{
-				ID:         modelID,
-				Name:       schemas.Ptr(modelName),
-				Deployment: schemas.Ptr(deploymentValue),
+			modelEntry := schemas.Model{
+				ID:          string(schemas.Vertex) + "/" + result.ResolvedID,
+				Name:        schemas.Ptr(model.DisplayName),
+				Description: schemas.Ptr(model.Description),
+				Created:     schemas.Ptr(model.VersionCreateTime.Unix()),
+			}
+			if result.AliasValue != "" {
+				modelEntry.Alias = schemas.Ptr(result.AliasValue)
+			}
+			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
+			included[resolvedKey] = true
 			}
-
-			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-			addedModelIDs[modelID] = true
-		}
 	}
 	}
 
-	// Third pass: Backfill allowed models that were not in the response or deployments
-	if !unfiltered && len(allowedModels) > 0 {
-		for _, allowedModel := range allowedModels {
-			// Check if model already exists in the list
-			modelID := string(schemas.Vertex) + "/" + allowedModel
-			if addedModelIDs[modelID] {
-				continue
-			}
-			if slices.Contains(blacklistedModels, allowedModel) {
-				continue
-			}
-
-			modelName := formatDeploymentName(allowedModel)
-			modelEntry := schemas.Model{
-				ID:   modelID,
-				Name: schemas.Ptr(modelName),
-			}
-
-			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-			addedModelIDs[modelID] = true
-		}
-	}
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
 
 	bifrostResponse.NextPageToken = response.NextPageToken
 
@@ -254,7 +139,7 @@ func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedMod
 
 // ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format.
 // This is for foundation models from the Model Garden (publishers.models.list endpoint).
-func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels []string, blacklistedModels []string, unfiltered bool) *schemas.BifrostListModelsResponse {
+func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
 	if response == nil {
 		return nil
 	}
@@ -263,8 +148,19 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a
 		Data: make([]schemas.Model, 0, len(response.PublisherModels)),
 	}
 
-	// Track which model IDs have been added to avoid duplicates
-	addedModelIDs := make(map[string]bool)
+	pipeline := &providerUtils.ListModelsPipeline{
+		AllowedModels:     allowedModels,
+		BlacklistedModels: blacklistedModels,
+		Aliases:           aliases,
+		Unfiltered:        unfiltered,
+		ProviderKey:       schemas.Vertex,
+		MatchFns:          providerUtils.DefaultMatchFns(),
+	}
+	if pipeline.ShouldEarlyExit() {
+		return bifrostResponse
+	}
+
+	included := make(map[string]bool)
 
 	for _, model := range response.PublisherModels {
 		// Extract model ID from name (format: "publishers/google/models/gemini-1.5-pro")
@@ -273,36 +169,28 @@ func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(a
 			continue
 		}
 
-		// Filter based on allowedModels if specified
-		if !unfiltered && len(allowedModels) > 0 && !slices.Contains(allowedModels, modelID) {
-			continue
-		}
-		if !unfiltered && slices.Contains(blacklistedModels, modelID) {
-			continue
-		}
-
-		// Skip if already added (shouldn't happen, but safety check)
-		fullModelID := string(schemas.Vertex) + "/" + modelID
-		if addedModelIDs[fullModelID] {
-			continue
-		}
-
-		// Extract display name from supported actions if available
-		displayName := modelID
-		if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" {
-			displayName = model.SupportedActions.Deploy.ModelDisplayName
-		}
-
-		modelEntry := schemas.Model{
-			ID:   fullModelID,
-			Name: schemas.Ptr(displayName),
+		for _, result := range pipeline.FilterModel(modelID) {
+			// Extract display name from supported actions if available
+			displayName := result.ResolvedID
+			if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" {
+				displayName = model.SupportedActions.Deploy.ModelDisplayName
+			}
+			modelEntry := schemas.Model{
+				ID:   string(schemas.Vertex) + "/" + result.ResolvedID,
+				Name: schemas.Ptr(displayName),
+			}
+			if result.AliasValue != "" {
+				modelEntry.Alias = schemas.Ptr(result.AliasValue)
+			}
+			bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
+			included[strings.ToLower(result.ResolvedID)] = true
 		}
-
-		bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
-		addedModelIDs[fullModelID] = true
 	}
 
+	bifrostResponse.Data = append(bifrostResponse.Data,
+		pipeline.BackfillModels(included)...)
+
 	bifrostResponse.NextPageToken = response.NextPageToken
 
 	return bifrostResponse
-}
+}
\ No newline at end of file
diff --git a/core/providers/vertex/rerank.go b/core/providers/vertex/rerank.go
index 74372658b2..b06430fcac 100644
--- a/core/providers/vertex/rerank.go
+++ b/core/providers/vertex/rerank.go
@@ -83,7 +83,7 @@ func getVertexRerankOptions(projectID string, params *schemas.RerankParameters)
 }
 
 // ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format.
-func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployment string, options *vertexRerankOptions) (*VertexRankRequest, error) {
+func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) {
 	if bifrostReq == nil {
 		return nil, fmt.Errorf("bifrost rerank request is nil")
 	}
@@ -132,7 +132,7 @@ func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, modelDeployme
 		rankRequest.TopN = &topN
 	}
 
-	if trimmedModel := strings.TrimSpace(modelDeployment); trimmedModel != "" {
+	if trimmedModel := strings.TrimSpace(bifrostReq.Model); trimmedModel != "" {
 		rankRequest.Model = &trimmedModel
 	}
 
diff --git a/core/providers/vertex/rerank_test.go b/core/providers/vertex/rerank_test.go
index afd8ed225e..3f2efcec52 100644
--- a/core/providers/vertex/rerank_test.go
+++ b/core/providers/vertex/rerank_test.go
@@ -42,7 +42,6 @@ func TestToVertexRankRequest(t *testing.T) {
 				TopN: schemas.Ptr(10),
 			},
 		},
-		"semantic-ranker-default@latest",
 		&vertexRerankOptions{
 			RankingConfig:                 "projects/p/locations/global/rankingConfigs/default_ranking_config",
 			IgnoreRecordDetailsInResponse: true,
@@ -77,7 +76,6 @@ func TestToVertexRankRequestTooManyRecords(t *testing.T) {
 			Query:     "q",
 			Documents: docs,
 		},
-		"",
 		&vertexRerankOptions{
 			RankingConfig:                 "projects/p/locations/global/rankingConfigs/default_ranking_config",
 			IgnoreRecordDetailsInResponse: true,
diff --git a/core/providers/vertex/types.go b/core/providers/vertex/types.go
index 97d6de7fa2..bbdb89d17f 100644
--- a/core/providers/vertex/types.go
+++ b/core/providers/vertex/types.go
@@ -192,23 +192,23 @@ type VertexModelLabels struct {
 
 // These types are for the publishers.models.list endpoint (Model Garden)
 type VertexPublisherModel struct {
-	Name string `json:"name"`
-	VersionID string `json:"versionId"`
-	OpenSourceCategory string `json:"openSourceCategory"`
-	LaunchStage string `json:"launchStage"`
-	VersionState string `json:"versionState"`
-	PublisherModelTemplate string `json:"publisherModelTemplate"`
-	SupportedActions *VertexPublisherModelActions `json:"supportedActions"`
+	Name                   string                       `json:"name"`
+	VersionID              string                       `json:"versionId"`
+	OpenSourceCategory     string                       `json:"openSourceCategory"`
+	LaunchStage            string                       `json:"launchStage"`
+	VersionState           string                       `json:"versionState"`
+	PublisherModelTemplate string                       `json:"publisherModelTemplate"`
+	SupportedActions       *VertexPublisherModelActions `json:"supportedActions"`
 }
 
 type VertexPublisherModelActions struct {
-	OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"`
-	OpenGenie *VertexPublisherModelURI `json:"openGenie"`
-	OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"`
-	OpenNotebook *VertexPublisherModelURI `json:"openNotebook"`
-	OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"`
-	Deploy *VertexPublisherModelDeploy `json:"deploy"`
-	OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"`
+	OpenGenerationAIStudio   *VertexPublisherModelURI    `json:"openGenerationAiStudio"`
+	OpenGenie                *VertexPublisherModelURI    `json:"openGenie"`
+	OpenPromptTuningPipeline *VertexPublisherModelURI    `json:"openPromptTuningPipeline"`
+	OpenNotebook             *VertexPublisherModelURI    `json:"openNotebook"`
+	OpenFineTuningPipeline   *VertexPublisherModelURI    `json:"openFineTuningPipeline"`
+	Deploy                   *VertexPublisherModelDeploy `json:"deploy"`
+	OpenEvaluationPipeline   *VertexPublisherModelURI    `json:"openEvaluationPipeline"`
 }
 
 type VertexPublisherModelURI struct {
diff --git a/core/providers/vertex/utils.go b/core/providers/vertex/utils.go
index 2e4e225f4d..0dfda6763d 100644
--- a/core/providers/vertex/utils.go
+++ b/core/providers/vertex/utils.go
@@ -9,7 +9,7 @@ import (
 	"github.com/maximhq/bifrost/core/schemas"
 )
 
-func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, providerName schemas.ModelProvider, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
+func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
 	// Large payload mode: body streams directly from the LP reader — skip all body building
 	// (matches CheckContextAndGetRequestBody guard).
 	if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
@@ -26,74 +26,74 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s
 		if isCountTokens {
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 			jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		} else {
 			// Add max_tokens if not present
 			if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
 				jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+					return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 				}
 			}
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 			// Add stream if streaming
 			if isStreaming {
 				jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
 				if err != nil {
-					return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+					return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 				}
 			}
 		}
 
 		jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305)
 		jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(err.Error(), nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
 		}
 		// Add anthropic_version if not present
 		if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
 			jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 	} else {
 		// Validate tools are supported by Vertex
 		if request.Params != nil && request.Params.Tools != nil {
 			if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil {
-				return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil, providerName)
+				return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil)
 			}
 		}
 		// Convert request to Anthropic format
 		reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
 		if convErr != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
 		}
 		if reqBody == nil {
-			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
 		}
 		reqBody.Model = deployment
@@ -109,44 +109,44 @@ func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *s
 		// Marshal struct to JSON bytes
 		jsonBody, err = providerUtils.MarshalSorted(reqBody)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 		// Add anthropic_version if not present (using sjson to preserve order)
 		if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
 			jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 
 		if isCountTokens {
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		} else {
 			// Remove model field for Vertex API (it's in URL)
 			jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+				return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 			}
 		}
 
 		jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 	}
 
 	if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 {
 		jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err, providerName)
+			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
 		}
 	}
 
@@ -178,29 +178,25 @@ func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID
 
 // buildResponseFromConfig builds a list models response from configured deployments and allowedModels.
 // This is used when the user has explicitly configured which models they want to use.
-func buildResponseFromConfig(deployments map[string]string, allowedModels []string, blacklistedModels []string) *schemas.BifrostListModelsResponse {
+func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse {
 	response := &schemas.BifrostListModelsResponse{
 		Data: make([]schemas.Model, 0),
 	}
+	if blacklistedModels.IsBlockAll() {
+		return response
+	}
+
 	addedModelIDs := make(map[string]bool)
 
-	// Build allowlist set for O(1) lookup
-	allowedSet := make(map[string]bool, len(allowedModels))
-	for _, m := range allowedModels {
-		allowedSet[m] = true
-	}
-	blacklistedSet := make(map[string]bool, len(blacklistedModels))
-	for _, m := range blacklistedModels {
-		blacklistedSet[m] = true
-	}
+	restrictAllowed := allowedModels.IsRestricted()
 
 	// First add models from deployments (filtered by allowedModels when set)
 	for alias, deploymentValue := range deployments {
-		if len(allowedSet) > 0 && !allowedSet[alias] {
+		if restrictAllowed && !allowedModels.Contains(alias) {
 			continue
 		}
-		if len(blacklistedSet) > 0 && blacklistedSet[alias] {
+		if blacklistedModels.IsBlocked(alias) {
 			continue
 		}
 		modelID := string(schemas.Vertex) + "/" + alias
@@ -208,28 +204,31 @@ func buildResponseFromConfig(deployments map[string]string, allowedModels []stri
 			continue
 		}
 
-		modelName := formatDeploymentName(alias)
+		modelName := providerUtils.ToDisplayName(alias)
 		modelEntry := schemas.Model{
-			ID:         modelID,
-			Name:       schemas.Ptr(modelName),
-			Deployment: schemas.Ptr(deploymentValue),
+			ID:    modelID,
+			Name:  schemas.Ptr(modelName),
+			Alias: schemas.Ptr(deploymentValue),
 		}
 		response.Data = append(response.Data, modelEntry)
 		addedModelIDs[modelID] = true
 	}
 
-	// Then add models from allowedModels that aren't already in deployments
+	// Then add models from allowedModels that aren't already in deployments (only when restricted)
+	if !restrictAllowed {
+		return response
+	}
 	for _, allowedModel := range allowedModels {
-		if len(blacklistedSet) > 0 && blacklistedSet[allowedModel] {
-			continue
-		}
 		modelID := string(schemas.Vertex) + "/" + allowedModel
 		if addedModelIDs[modelID] {
 			continue
 		}
+		if blacklistedModels.IsBlocked(allowedModel) {
+			continue
+		}
 
-		modelName := formatDeploymentName(allowedModel)
+		modelName := providerUtils.ToDisplayName(allowedModel)
 		modelEntry := schemas.Model{
 			ID:   modelID,
 			Name: schemas.Ptr(modelName),
diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go
index 37e7af90f2..9a4792eb6a 100644
--- a/core/providers/vertex/vertex.go
+++ b/core/providers/vertex/vertex.go
@@ -114,9 +114,6 @@ const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
 // It uses the JWT config if auth credentials are provided.
 // It returns an error if the token source creation fails.
 func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) {
-	if key.VertexKeyConfig == nil {
-		return nil, fmt.Errorf("vertex key config is not set")
-	}
 	authCredentials := key.VertexKeyConfig.AuthCredentials
 	var tokenSource oauth2.TokenSource
 	if authCredentials.GetValue() == "" {
@@ -176,23 +173,21 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider {
 // 1. If deployments or allowedModels are configured, return those (no API call needed)
 // 2. Otherwise, fetch from the publishers.models.list API endpoint (Model Garden)
 func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
 	region := key.VertexKeyConfig.Region.GetValue()
 	if region == "" {
-		return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+		return nil, providerUtils.NewConfigurationError("region is not set in key config")
 	}
 
-	deployments := key.VertexKeyConfig.Deployments
+	deployments := key.Aliases
 	allowedModels := key.Models
 
+	if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) {
+		return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil
+	}
+
 	// If deployments or allowedModels are configured, return those directly without API call
 	// Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved
-	if !request.Unfiltered && (len(deployments) > 0 || len(allowedModels) > 0) {
+	if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) {
 		return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil
 	}
 
@@ -213,11 +208,11 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 	// Getting oauth2 token
 	tokenSource, err := getAuthTokenSource(key)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err)
 	}
 	token, err := tokenSource.Token()
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err)
 	}
 
 	// Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models
@@ -246,13 +241,14 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 		_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
 		if bifrostErr != nil {
 			wait()
+			respBody := append([]byte(nil), resp.Body()...)
 			fasthttp.ReleaseRequest(req)
 			fasthttp.ReleaseResponse(resp)
 			// Non-Google publishers may not be available in all regions; skip on error
 			if publisher != "google" {
 				break
 			}
-			return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
 		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
@@ -280,9 +276,9 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 
 			var errorResp VertexError
 			if err := sonic.Unmarshal(respBody, &errorResp); err != nil {
-				return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+				return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 			}
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, schemas.Vertex, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
 		// Parse Vertex's publisher models response
@@ -322,7 +318,7 @@ func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key
 		PublisherModels: allPublisherModels,
 	}
 
-	response := aggregatedResponse.ToBifrostListModelsResponse(nil, key.BlacklistedModels, request.Unfiltered)
+	response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
 
 	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
 		response.ExtraFields.RawRequest = rawRequests
@@ -368,18 +364,6 @@ func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext
 // It supports both text and image content in messages.
 // Returns a BifrostResponse containing the completion results or an error if the request fails.
 func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
-	deployment := provider.getModelDeployment(key, request.Model)
-	// strip google/ prefix if present
-	if after, ok := strings.CutPrefix(deployment, "google/"); ok {
-		deployment = after
-	}
-
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
@@ -389,7 +373,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 			var extraParams map[string]interface{}
 			var err error
 
-			if schemas.IsAnthropicModel(deployment) {
+			if schemas.IsAnthropicModel(request.Model) {
 				// Use centralized Anthropic converter
 				reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
 				if convErr != nil {
@@ -399,7 +383,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 					return nil, fmt.Errorf("chat completion input is not provided")
 				}
 				extraParams = reqBody.GetExtraParams()
-				reqBody.Model = deployment
 				// Add provider-aware beta headers for Vertex
 				anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
 				// Marshal to JSON bytes, preserving struct field order
@@ -426,7 +409,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 				if err != nil {
 					return nil, fmt.Errorf("failed to delete model field: %w", err)
 				}
-			} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
+			} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
 				reqBody, err := gemini.ToGeminiChatCompletionRequest(request)
 				if err != nil {
 					return nil, err
@@ -435,7 +418,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 					return nil, fmt.Errorf("chat completion input is not provided")
 				}
 				extraParams = reqBody.GetExtraParams()
-				reqBody.Model = deployment
 				// Strip unsupported fields for Vertex Gemini
 				stripVertexGeminiUnsupportedFields(reqBody)
 				// Marshal to JSON bytes
@@ -450,7 +432,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 					return nil, fmt.Errorf("chat completion input is not provided")
 				}
 				extraParams = reqBody.GetExtraParams()
-				reqBody.Model = deployment
 				// Marshal to JSON bytes
 				rawBody, err = providerUtils.MarshalSorted(reqBody)
 				if err != nil {
@@ -465,26 +446,26 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 			}
 			return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
 		},
-		provider.GetProviderKey())
+	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	projectID := key.VertexKeyConfig.ProjectID.GetValue()
 	if projectID == "" {
-		return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+		return nil, providerUtils.NewConfigurationError("project ID is not set")
 	}
 	region := key.VertexKeyConfig.Region.GetValue()
 	if region == "" {
-		return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+		return nil, providerUtils.NewConfigurationError("region is not set in key config")
 	}
 
 	// Remap unsupported tool versions for Vertex (handles raw passthrough bodies)
-	if schemas.IsAnthropicModel(deployment) && jsonBody != nil {
+	if schemas.IsAnthropicModel(request.Model) && jsonBody != nil {
 		remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
 		if remapErr != nil {
-			return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName)
+			return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil)
 		}
 		jsonBody = remappedBody
 	}
@@ -493,43 +474,43 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key
 	authQuery := ""
 
 	// Determine the URL based on model type
var completeURL string - if schemas.IsAllDigitsASCII(deployment) { + if schemas.IsAllDigitsASCII(request.Model) { // Custom Fine-tuned models use OpenAPI endpoint projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() if projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model) } - } else if schemas.IsAnthropicModel(deployment) { + } else if schemas.IsAnthropicModel(request.Model) { // Claude models use Anthropic publisher if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsMistralModel(deployment) { + } else if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with rawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model) } - } else if schemas.IsGeminiModel(deployment) { + } else if schemas.IsGeminiModel(request.Model) { // Gemini models support api key if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, 
request.Model) } } else { if region == "global" { @@ -564,11 +545,11 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -597,14 +578,10 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -613,16 +590,13 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return &schemas.BifrostChatResponse{ Model: request.Model, 
ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ChatCompletionRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Create response object from pool anthropicResponse := anthropic.AcquireAnthropicMessageResponse() defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse) @@ -636,17 +610,9 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key response := anthropicResponse.ToBifrostChatResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ChatCompletionRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), - } - - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment + Latency: latency.Milliseconds(), + ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), } - response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -659,7 +625,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -668,12 
+634,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key } response := geminiResponse.ToBifrostChatResponse() - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -695,12 +655,6 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ChatCompletionRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -723,35 +677,17 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key // Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails. 
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after + return nil, providerUtils.NewConfigurationError("region is not set in key config") } - postResponseConverter := func(response *schemas.BifrostChatResponse) *schemas.BifrostChatResponse { - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } - return response - } - - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { // Use Anthropic-style streaming for Claude models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -766,8 +702,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Model = deployment - reqBody.Stream = schemas.Ptr(true) + reqBody.Stream = schemas.Ptr(true) // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody,
schemas.Vertex) @@ -803,7 +738,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext } return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -813,15 +748,15 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext var remapErr error jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex) if remapErr != nil { - return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil, providerName) + return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil) } } var completeURL string if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment) + completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) } // Prepare headers for Vertex Anthropic @@ -834,11 +769,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Adding authorization header tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error 
getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken @@ -855,15 +790,10 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse), providerName, postHookRunner, - postResponseConverter, + nil, provider.logger, - &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ChatCompletionStreamRequest, - }, ) - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { // Use Gemini-style streaming for Gemini models jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, @@ -876,12 +806,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if reqBody == nil { return nil, fmt.Errorf("chat completion input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } @@ -894,12 +823,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini streaming - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, 
projectNumber, ":streamGenerateContent") + completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent") // Add alt=sse parameter if authQuery != "" { @@ -918,11 +847,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext if authQuery == "" { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } headers["Authorization"] = "Bearer " + token.AccessToken } @@ -940,7 +869,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext provider.GetProviderKey(), request.Model, postHookRunner, - postResponseConverter, + nil, provider.logger, ) } else { @@ -949,12 +878,12 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext authQuery := "" // Determine the URL based on model type var completeURL string - if schemas.IsMistralModel(deployment) { + if schemas.IsMistralModel(request.Model) { // Mistral models use mistralai publisher with streamRawPredict if region == "global" { - completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, deployment) + completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model) } else { - completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, deployment) + 
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model) } } else { // Other models use OpenAPI endpoint for gemini models @@ -974,22 +903,17 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } authHeader = map[string]string{ "Authorization": "Bearer " + token.AccessToken, } } - postRequestConverter := func(reqBody *openai.OpenAIChatRequest) *openai.OpenAIChatRequest { - reqBody.Model = deployment - return reqBody - } - // Use shared OpenAI streaming logic return openai.HandleOpenAIChatCompletionStreaming( ctx, @@ -1005,8 +929,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext nil, nil, nil, - postRequestConverter, - postResponseConverter, + nil, + nil, provider.logger, ) } @@ -1014,40 +938,28 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext // Responses performs a responses request to the Vertex API. 
func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Claude models use Anthropic publisher var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, deployment) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model) } else { - url = 
fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, deployment) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } // Create HTTP request for streaming @@ -1068,11 +980,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) @@ -1100,14 +1012,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ResponsesRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, 
provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -1115,9 +1023,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostResponsesResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ResponsesRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, @@ -1137,13 +1042,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem response := anthropicResponse.ToBifrostResponsesResponse(ctx) response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.ResponsesRequest, - Provider: providerName, - ModelRequested: request.Model, - Latency: latency.Milliseconds(), + Latency: latency.Milliseconds(), } - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) // Set raw request if enabled if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -1154,12 +1055,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { response.ExtraFields.RawResponse = rawResponse } - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } return response, nil - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -1171,24 +1069,23 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem if reqBody == nil { return nil, 
fmt.Errorf("responses input is not provided") } - reqBody.Model = deployment // Strip unsupported fields for Vertex Gemini stripVertexGeminiUnsupportedFields(reqBody) return reqBody, nil }, - provider.GetProviderKey()) + ) if bifrostErr != nil { return nil, bifrostErr } projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } authQuery := "" @@ -1198,11 +1095,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":generateContent") + url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":generateContent") // Create HTTP request for streaming req := fasthttp.AcquireRequest() @@ -1227,11 +1124,11 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err)
 	}
 	token, err := tokenSource.Token()
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
 }
@@ -1260,14 +1157,10 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem
 	if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
 		removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
 	}
-	return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{
-		Provider:    providerName,
-		Model:       request.Model,
-		RequestType: schemas.ResponsesRequest,
-	}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 }
 
-	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1275,9 +1168,6 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem
 		respOwned = false
 		return &schemas.BifrostResponsesResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ResponsesRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 			},
@@ -1292,16 +1182,9 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem
 	}
 
 	response := geminiResponse.ToResponsesBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = providerName
-	response.ExtraFields.ModelRequested = request.Model
 	response.ExtraFields.Latency = latency.Milliseconds()
 	response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
 
-	if request.Model != deployment {
-		response.ExtraFields.ModelDeployment = deployment
-	}
-
 	// Set raw response if enabled
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		response.ExtraFields.RawResponse = rawResponse
@@ -1319,52 +1202,33 @@ func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schem
 	}
 
 	response := chatResponse.ToBifrostResponsesResponse()
-	response.ExtraFields.RequestType = schemas.ResponsesRequest
-	response.ExtraFields.Provider = provider.GetProviderKey()
-	response.ExtraFields.ModelRequested = request.Model
-	if request.Model != deployment {
-		response.ExtraFields.ModelDeployment = deployment
-	}
-
 	return response, nil
 	}
 }
 
 // ResponsesStream performs a streaming responses request to the Vertex API.
 func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
-	deployment := provider.getModelDeployment(key, request.Model)
-	// strip google/ prefix if present
-	if after, ok := strings.CutPrefix(deployment, "google/"); ok {
-		deployment = after
-	}
-
-	if schemas.IsAnthropicModel(deployment) {
+	if schemas.IsAnthropicModel(request.Model) {
 		region := key.VertexKeyConfig.Region.GetValue()
 		if region == "" {
-			return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+			return nil, providerUtils.NewConfigurationError("region is not set in key config")
 		}
 		projectID := key.VertexKeyConfig.ProjectID.GetValue()
 		if projectID == "" {
-			return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+			return nil, providerUtils.NewConfigurationError("project ID is not set")
 		}
 
-		jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
+		jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
 
 		var url string
 		if region == "global" {
-			url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, deployment)
+			url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model)
 		} else {
-			url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, deployment)
+			url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
 		}
 
 		// Prepare headers for Vertex Anthropic
@@ -1377,22 +1241,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 		// Adding authorization header
 		tokenSource, err := getAuthTokenSource(key)
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex)
+			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
 		}
 		token, err := tokenSource.Token()
 		if err != nil {
-			return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex)
+			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 		}
 		headers["Authorization"] = "Bearer " + token.AccessToken
 
-		postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse {
-			response.ExtraFields.ModelRequested = request.Model
-			if request.Model != deployment {
-				response.ExtraFields.ModelDeployment = deployment
-			}
-			return response
-		}
-
 		// Use shared streaming logic from Anthropic
 		return anthropic.HandleAnthropicResponsesStream(
 			ctx,
@@ -1406,23 +1262,18 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
 			provider.GetProviderKey(),
 			postHookRunner,
-			postResponseConverter,
+			nil,
 			provider.logger,
-			&providerUtils.RequestMetadata{
-				Provider:    provider.GetProviderKey(),
-				Model:       request.Model,
-				RequestType: schemas.ResponsesStreamRequest,
-			},
 		)
-	} else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
+	} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
 		region := key.VertexKeyConfig.Region.GetValue()
 		if region == "" {
-			return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+			return nil, providerUtils.NewConfigurationError("region is not set in key config")
 		}
 		projectID := key.VertexKeyConfig.ProjectID.GetValue()
 		if projectID == "" {
-			return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+			return nil, providerUtils.NewConfigurationError("project ID is not set")
 		}
 
 		// Use Gemini-style streaming for Gemini models
@@ -1437,12 +1288,11 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 			if reqBody == nil {
 				return nil, fmt.Errorf("responses input is not provided")
 			}
-			reqBody.Model = deployment
 			// Strip unsupported fields for Vertex Gemini
 			stripVertexGeminiUnsupportedFields(reqBody)
 			return reqBody, nil
 		},
-			provider.GetProviderKey())
+		)
 		if bifrostErr != nil {
 			return nil, bifrostErr
 		}
@@ -1455,12 +1305,12 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 
 		// For custom/fine-tuned models, validate projectNumber is set
 		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
-		if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" {
-			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
+		if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
+			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
 		}
 
 		// Construct the URL for Gemini streaming
-		completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":streamGenerateContent")
+		completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent")
 		// Add alt=sse parameter
 		if authQuery != "" {
 			completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
@@ -1478,23 +1328,15 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 		if authQuery == "" {
 			tokenSource, err := getAuthTokenSource(key)
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex)
+				return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
 			}
 			token, err := tokenSource.Token()
 			if err != nil {
-				return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex)
+				return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 			}
 			headers["Authorization"] = "Bearer " + token.AccessToken
 		}
 
-		postResponseConverter := func(response *schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse {
-			response.ExtraFields.ModelRequested = request.Model
-			if request.Model != deployment {
-				response.ExtraFields.ModelDeployment = deployment
-			}
-			return response
-		}
-
 		// Use shared streaming logic from Gemini
 		return gemini.HandleGeminiResponsesStream(
 			ctx,
@@ -1508,7 +1350,7 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 			provider.GetProviderKey(),
 			request.Model,
 			postHookRunner,
-			postResponseConverter,
+			nil,
 			provider.logger,
 		)
 	} else {
@@ -1526,18 +1368,14 @@ func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, pos
 
 // All Vertex AI embedding models use the same response format regardless of the model type.
 // Returns a BifrostResponse containing the embedding(s) and any error that occurred.
 func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
 	projectID := key.VertexKeyConfig.ProjectID.GetValue()
 	if projectID == "" {
-		return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+		return nil, providerUtils.NewConfigurationError("project ID is not set")
 	}
 	region := key.VertexKeyConfig.Region.GetValue()
 	if region == "" {
-		return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+		return nil, providerUtils.NewConfigurationError("region is not set in key config")
 	}
 
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -1546,24 +1384,19 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
 			return ToVertexEmbeddingRequest(request), nil
 		},
-		providerName)
+	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
-	deployment := provider.getModelDeployment(key, request.Model)
-
-	// Remove google/ prefix from deployment
-	deployment = strings.TrimPrefix(deployment, "google/")
-
 	// For custom/fine-tuned models, validate projectNumber is set
 	projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
-	if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" {
-		return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
+	if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
+		return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
 	}
 
 	// Build the native Vertex embedding API endpoint
-	url := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predict")
+	url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict")
 
 	// Create HTTP request for streaming
 	req := fasthttp.AcquireRequest()
@@ -1586,11 +1419,11 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 	// Getting oauth2 token
 	tokenSource, err := getAuthTokenSource(key)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
 	}
 	token, err := tokenSource.Token()
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
@@ -1626,7 +1459,7 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 		// Try to parse Vertex's error format
 		var vertexError map[string]interface{}
 		if err := sonic.Unmarshal(errBody, &vertexError); err != nil {
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
 		if errorObj, exists := vertexError["error"]; exists {
@@ -1640,10 +1473,10 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 			}
 		}
 
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), schemas.Vertex, nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1651,9 +1484,6 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 		respOwned = false
 		return &schemas.BifrostEmbeddingResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.EmbeddingRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 			},
@@ -1663,28 +1493,21 @@ func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schem
 	// Parse Vertex's native embedding response using typed response
 	var vertexResponse VertexEmbeddingResponse
 	if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.Vertex), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
 	// Use centralized Vertex converter
 	bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse()
 
 	// Set ExtraFields
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
 
-	if bifrostResponse.ExtraFields.ModelRequested != deployment {
-		bifrostResponse.ExtraFields.ModelDeployment = deployment
-	}
-
 	// Set raw response if enabled
 	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
 		var rawResponseMap map[string]interface{}
 		if err := sonic.Unmarshal(resp.Body(), &rawResponseMap); err != nil {
-			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err, providerName), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
+			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 		bifrostResponse.ExtraFields.RawResponse = rawResponseMap
 	}
@@ -1699,30 +1522,23 @@ func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas.
 
 // Rerank performs a rerank request using Vertex Discovery Engine ranking API.
 func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
 	projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
 	if projectID == "" {
-		return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+		return nil, providerUtils.NewConfigurationError("project ID is not set")
 	}
 
 	options, err := getVertexRerankOptions(projectID, request.Params)
 	if err != nil {
-		return nil, providerUtils.NewConfigurationError(err.Error(), providerName)
+		return nil, providerUtils.NewConfigurationError(err.Error())
 	}
 
-	modelDeployment := provider.getModelDeployment(key, request.Model)
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
 		ctx,
 		request,
 		func() (providerUtils.RequestBodyWithExtraParams, error) {
-			return ToVertexRankRequest(request, modelDeployment, options)
+			return ToVertexRankRequest(request, options)
 		},
-		providerName)
+	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
@@ -1748,11 +1564,11 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 	tokenSource, err := getAuthTokenSource(key)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
 	}
 	token, err := tokenSource.Token()
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error getting token", err, providerName)
+		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
@@ -1780,11 +1596,7 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 		}
 
 		errorMessage := parseDiscoveryEngineErrorMessage(resp.Body())
-		parsedError := parseVertexError(resp, &providerUtils.RequestMetadata{
-			Provider:    providerName,
-			Model:       request.Model,
-			RequestType: schemas.RerankRequest,
-		})
+		parsedError := parseVertexError(resp)
 
 		if strings.TrimSpace(errorMessage) != "" {
 			shouldOverride := parsedError == nil ||
@@ -1794,19 +1606,14 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 				parsedError.Error.Message == schemas.ErrProviderResponseUnmarshal
 
 			if shouldOverride {
-				parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), providerName, nil, nil)
-				parsedError.ExtraFields = schemas.BifrostErrorExtraFields{
-					Provider:       providerName,
-					ModelRequested: request.Model,
-					RequestType:    schemas.RerankRequest,
-				}
+				parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil)
 			}
 		}
 
 		return nil, providerUtils.EnrichError(ctx, parsedError, jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -1815,9 +1622,6 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 		return &schemas.BifrostRerankResponse{
 			Model: request.Model,
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.RerankRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 			},
@@ -1833,16 +1637,9 @@ func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.
 	returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments
 	bifrostResponse, err := vertexResponse.ToBifrostRerankResponse(request.Documents, returnDocuments)
 	if err != nil {
-		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
+		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
 
-	bifrostResponse.Model = request.Model
-	bifrostResponse.ExtraFields.Provider = providerName
-	bifrostResponse.ExtraFields.ModelRequested = request.Model
-	if request.Model != modelDeployment {
-		bifrostResponse.ExtraFields.ModelDeployment = modelDeployment
-	}
-	bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest
 	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
 	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
@@ -1878,21 +1675,9 @@ func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext,
 }
 
 func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
-	deployment := provider.getModelDeployment(key, request.Model)
-	// strip google/ prefix if present
-	if after, ok := strings.CutPrefix(deployment, "google/"); ok {
-		deployment = after
-	}
-
 	// Validate model type before processing
-	if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", deployment), providerName)
+	if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) {
+		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model))
 	}
 
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -1903,13 +1688,12 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 			var extraParams map[string]interface{}
 			var err error
 
-			if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
+			if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
 				reqBody := gemini.ToGeminiImageGenerationRequest(request)
 				if reqBody == nil {
 					return nil, fmt.Errorf("image generation input is not provided")
 				}
 				extraParams = reqBody.GetExtraParams()
-				reqBody.Model = deployment
 				// Strip unsupported fields for Vertex Gemini
 				stripVertexGeminiUnsupportedFields(reqBody)
 				// Marshal to JSON bytes, preserving key order
@@ -1917,7 +1701,7 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 				if err != nil {
 					return nil, fmt.Errorf("failed to marshal request body: %w", err)
 				}
-			} else if schemas.IsImagenModel(deployment) {
+			} else if schemas.IsImagenModel(request.Model) {
 				reqBody := gemini.ToImagenImageGenerationRequest(request)
 				if reqBody == nil {
 					return nil, fmt.Errorf("image generation input is not provided")
@@ -1937,58 +1721,58 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 			}
 			return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
 		},
-		provider.GetProviderKey())
+	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	projectID := key.VertexKeyConfig.ProjectID.GetValue()
 	if projectID == "" {
-		return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+		return nil, providerUtils.NewConfigurationError("project ID is not set")
 	}
 	region := key.VertexKeyConfig.Region.GetValue()
 	if region == "" {
-		return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+		return nil, providerUtils.NewConfigurationError("region is not set in key config")
 	}
 
 	// Auth query is used for fine-tuned models to pass the API key in the query string
 	authQuery := ""
 	// Determine the URL based on model type
 	var completeURL string
-	if schemas.IsAllDigitsASCII(deployment) {
+	if schemas.IsAllDigitsASCII(request.Model) {
 		// Custom Fine-tuned models use OpenAPI endpoint
 		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
 		if projectNumber == "" {
-			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
+			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
 		}
 		if value := key.Value.GetValue(); value != "" {
 			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
 		}
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment)
+			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
 		}
-	} else if schemas.IsImagenModel(deployment) {
+	} else if schemas.IsImagenModel(request.Model) {
 		// Imagen models are published models, use publishers/google/models path
 		if value := key.Value.GetValue(); value != "" {
 			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
 		}
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment)
+			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
 		}
-	} else if schemas.IsGeminiModel(deployment) {
+	} else if schemas.IsGeminiModel(request.Model) {
 		if value := key.Value.GetValue(); value != "" {
 			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
 		}
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, deployment)
+			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
 		}
 	}
@@ -2015,11 +1799,11 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	// Getting oauth2 token
 	tokenSource, err := getAuthTokenSource(key)
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
 	}
 	token, err := tokenSource.Token()
 	if err != nil {
-		return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex)
+		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+token.AccessToken)
 }
@@ -2048,14 +1832,10 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 	if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
 		removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
 	}
-	return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{
-		Provider:    providerName,
-		Model:       request.Model,
-		RequestType: schemas.ImageGenerationRequest,
-	}), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
+	return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 }
 
-	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger)
+	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
 	if decodeErr != nil {
 		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
 	}
@@ -2063,16 +1843,13 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		respOwned = false
 		return &schemas.BifrostImageGenerationResponse{
 			ExtraFields: schemas.BifrostResponseExtraFields{
-				Provider:                providerName,
-				ModelRequested:          request.Model,
-				RequestType:             schemas.ImageGenerationRequest,
 				Latency:                 latency.Milliseconds(),
 				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
 			},
 		}, nil
 	}
 
-	if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
+	if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
 		geminiResponse := gemini.GenerateContentResponse{}
 		rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
@@ -2085,12 +1862,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 			return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
 		}
 
-		response.ExtraFields.RequestType = schemas.ImageGenerationRequest
-		response.ExtraFields.Provider = providerName
-		response.ExtraFields.ModelRequested = request.Model
-		if request.Model != deployment {
-			response.ExtraFields.ModelDeployment = deployment
-		}
 		response.ExtraFields.Latency = latency.Milliseconds()
 		response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
@@ -2113,12 +1884,6 @@ func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key
 		}
 
 		response := imagenResponse.ToBifrostImageGenerationResponse()
-		response.ExtraFields.RequestType = schemas.ImageGenerationRequest
-		response.ExtraFields.Provider = providerName
-		response.ExtraFields.ModelRequested = request.Model
-		if request.Model != deployment {
-			response.ExtraFields.ModelDeployment = deployment
-		}
 		response.ExtraFields.Latency = latency.Milliseconds()
 		response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
@@ -2142,20 +1907,9 @@ func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContex
 
 // ImageEdit edits images for the given input text(s) using Vertex AI.
 // Returns a BifrostResponse containing the images and any error that occurred.
 func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
-	providerName := provider.GetProviderKey()
-
-	if key.VertexKeyConfig == nil {
-		return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName)
-	}
-
-	deployment := provider.getModelDeployment(key, request.Model)
-	if after, ok := strings.CutPrefix(deployment, "google/"); ok {
-		deployment = after
-	}
-
 	// Validate model type before processing
-	if !schemas.IsGeminiModel(deployment) && !schemas.IsAllDigitsASCII(deployment) && !schemas.IsImagenModel(deployment) {
-		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", deployment), providerName)
+	if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) {
+		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model))
 	}
 
 	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
@@ -2166,13 +1920,12 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			var extraParams map[string]interface{}
 			var err error
 
-			if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) {
+			if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
 				reqBody := gemini.ToGeminiImageEditRequest(request)
 				if reqBody == nil {
 					return nil, fmt.Errorf("image edit input is not provided")
 				}
 				extraParams = reqBody.GetExtraParams()
-				reqBody.Model = deployment
 				// Strip unsupported fields for Vertex Gemini
 				stripVertexGeminiUnsupportedFields(reqBody)
 				// Marshal to JSON bytes, preserving key order
@@ -2180,7 +1933,7 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 				if err != nil {
 					return nil, fmt.Errorf("failed to marshal request body: %w", err)
 				}
-			} else if schemas.IsImagenModel(deployment) {
+			} else if schemas.IsImagenModel(request.Model) {
 				reqBody := gemini.ToImagenImageEditRequest(request)
 				if reqBody == nil {
 					return nil, fmt.Errorf("image edit input is not provided")
@@ -2200,19 +1953,19 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 			}
 			return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
 		},
-		provider.GetProviderKey())
+	)
 	if bifrostErr != nil {
 		return nil, bifrostErr
 	}
 
 	projectID := key.VertexKeyConfig.ProjectID.GetValue()
 	if projectID == "" {
-		return nil, providerUtils.NewConfigurationError("project ID is not set", providerName)
+		return nil, providerUtils.NewConfigurationError("project ID is not set")
 	}
 	region := key.VertexKeyConfig.Region.GetValue()
 	if region == "" {
-		return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName)
+		return nil, providerUtils.NewConfigurationError("region is not set in key config")
 	}
 
 	authQuery := ""
@@ -2221,27 +1974,27 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem
 	}
 
 	var completeURL string
-	if schemas.IsAllDigitsASCII(deployment) {
+	if schemas.IsAllDigitsASCII(request.Model) {
 		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
 		if projectNumber == "" {
-			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName)
+			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
 		}
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, deployment)
+			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
 		}
-	} else if schemas.IsImagenModel(deployment) {
+	} else if schemas.IsImagenModel(request.Model) {
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, deployment)
+			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
 		}
-	} else if schemas.IsGeminiModel(deployment) {
+	} else if schemas.IsGeminiModel(request.Model) {
 		if region == "global" {
-			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, deployment)
+			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
 		} else {
-			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region,
request.Model) } } @@ -2267,11 +2020,11 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem // Getting oauth2 token tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2299,14 +2052,10 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.ImageEditRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -2314,16 +2063,13 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem respOwned = false return &schemas.BifrostImageGenerationResponse{ ExtraFields: 
schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.ImageEditRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { geminiResponse := gemini.GenerateContentResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -2336,12 +2082,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) } - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2364,12 +2104,6 @@ func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schem } response := imagenResponse.ToBifrostImageGenerationResponse() - response.ExtraFields.RequestType = schemas.ImageEditRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -2400,18 +2134,9 @@ func (provider 
*VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { providerName := provider.GetProviderKey() - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, bifrostReq.Model) - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - // Only Gemini models support video generation in Vertex - if !schemas.IsVeoModel(deployment) && !schemas.IsAllDigitsASCII(deployment) { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", deployment), providerName) + if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) { + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model)) } // Convert Bifrost request to Gemini format (reusing Gemini converters) @@ -2421,7 +2146,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiVideoGenerationRequest(bifrostReq) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2429,12 +2153,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set 
in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Auth query is used to pass the API key in the query string @@ -2445,12 +2169,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // For custom/fine-tuned models, validate projectNumber is set projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } // Construct the URL for Gemini video generation using predictLongRunning - completeURL := getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":predictLongRunning") + completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning") // Create HTTP request req := fasthttp.AcquireRequest() @@ -2469,11 +2193,11 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2493,17 +2217,13 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == 
fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: bifrostReq.Model, - RequestType: schemas.VideoGenerationRequest, - }), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } // Parse response body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } var operation gemini.GenerateVideosOperation @@ -2519,12 +2239,6 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.ModelRequested = bifrostReq.Model - if bifrostReq.Model != deployment { - bifrostResp.ExtraFields.ModelDeployment = deployment - } - bifrostResp.ExtraFields.RequestType = schemas.VideoGenerationRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -2540,18 +2254,12 @@ func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key // VideoRetrieve retrieves the status of a video generation operation. // Uses the fetchPredictOperation endpoint for Vertex AI. 
func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not set in key config") } // Construct base URL based on region @@ -2567,12 +2275,12 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID // We need to extract the model path from it to construct the fetchPredictOperation endpoint // Extract: projects/.../models/MODEL_ID from the operation name - taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, providerName) + taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey()) var modelPath string if idx := strings.Index(taskID, "/operations/"); idx != -1 { modelPath = taskID[:idx] } else { - return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil) } // Construct the URL: https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation @@ -2587,7 +2295,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // Create request body with operation name (using sjson to avoid map 
marshaling) jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err) } // Create HTTP request @@ -2607,11 +2315,11 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2631,10 +2339,7 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - RequestType: schemas.VideoRetrieveRequest, - }), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse) } // Parse response @@ -2648,10 +2353,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s if bifrostErr != nil { return nil, bifrostErr } - bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName) + bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey()) 
bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoRetrieveRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) if sendBackRawResponse { @@ -2665,9 +2368,8 @@ func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key s // First retrieves the video status to get the URL, then downloads the content. // Handles both regular URLs and data URLs (base64-encoded videos). func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() if request == nil || request.ID == "" { - return nil, providerUtils.NewBifrostOperationError("video_id is required", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video_id is required", nil) } // Retrieve operation first to get the video URL bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{ @@ -2681,12 +2383,10 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if videoResp.Status != schemas.VideoStatusCompleted { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("video not ready, current status: %s", videoResp.Status), - nil, - providerName, - ) + nil) } if len(videoResp.Videos) == 0 { - return nil, providerUtils.NewBifrostOperationError("video URL not available", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("video URL not available", nil) } var content []byte var latency time.Duration @@ -2698,7 +2398,7 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s startTime := time.Now() decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("failed to decode base64 video data", err, providerName) + return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err) } content = decoded contentType = videoResp.Videos[0].ContentType @@ -2728,11 +2428,11 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -2747,19 +2447,17 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s if resp.StatusCode() != fasthttp.StatusOK { return nil, providerUtils.NewBifrostOperationError( fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()), - nil, - providerName, - ) + nil) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } contentType = string(resp.Header.ContentType()) content = append([]byte(nil), body...) 
providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) } else { - return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil) } bifrostResp := &schemas.BifrostVideoDownloadResponse{ @@ -2769,8 +2467,6 @@ func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key s } bifrostResp.ExtraFields.Latency = latency.Milliseconds() - bifrostResp.ExtraFields.Provider = providerName - bifrostResp.ExtraFields.RequestType = schemas.VideoDownloadRequest bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders return bifrostResp, nil @@ -2808,19 +2504,6 @@ func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequ } } -func (provider *VertexProvider) getModelDeployment(key schemas.Key, model string) string { - if key.VertexKeyConfig == nil { - return model - } - - if key.VertexKeyConfig.Deployments != nil { - if deployment, ok := key.VertexKeyConfig.Deployments[model]; ok { - return deployment - } - } - return model -} - // BatchCreate is not supported by Vertex AI provider. func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey()) @@ -2880,25 +2563,13 @@ func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schem // CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint. // Supports Gemini models with both text and image content. 
func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewConfigurationError("vertex key config is not set", providerName) - } - - deployment := provider.getModelDeployment(key, request.Model) - // strip google/ prefix if present - if after, ok := strings.CutPrefix(deployment, "google/"); ok { - deployment = after - } - var ( jsonBody []byte bifrostErr *schemas.BifrostError ) - if schemas.IsAnthropicModel(deployment) { - jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, deployment, providerName, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) + if schemas.IsAnthropicModel(request.Model) { + jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders) if bifrostErr != nil { return nil, bifrostErr } @@ -2909,7 +2580,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch func() (providerUtils.RequestBodyWithExtraParams, error) { return gemini.ToGeminiResponsesRequest(request) }, - providerName, ) if bifrostErr != nil { return nil, bifrostErr @@ -2927,38 +2597,38 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch projectID := key.VertexKeyConfig.ProjectID.GetValue() if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", providerName) + return nil, providerUtils.NewConfigurationError("project ID is not set") } region := key.VertexKeyConfig.Region.GetValue() if region == "" { - return nil, providerUtils.NewConfigurationError("region is not set in key config", providerName) + return nil, providerUtils.NewConfigurationError("region is not 
set in key config") } authQuery := "" var completeURL string - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { if region == "global" { completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID) } else { completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { if key.Value.GetValue() != "" { authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue())) } projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue() - if schemas.IsAllDigitsASCII(deployment) && projectNumber == "" { - return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models", providerName) + if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" { + return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models") } - completeURL = getCompleteURLForGeminiEndpoint(deployment, region, projectID, projectNumber, ":countTokens") + completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens") } if completeURL == "" { - return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model/deployment: %s", deployment), providerName) + return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model)) } req := fasthttp.AcquireRequest() @@ -2980,11 +2650,11 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } else { tokenSource, err := getAuthTokenSource(key) if err != nil { - return nil, 
providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } req.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3012,14 +2682,10 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) } - return nil, providerUtils.EnrichError(ctx, parseVertexError(resp, &providerUtils.RequestMetadata{ - Provider: providerName, - Model: request.Model, - RequestType: schemas.CountTokensRequest, - }), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, providerName, provider.logger) + responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger) if decodeErr != nil { return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } @@ -3027,16 +2693,13 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch respOwned = false return &schemas.BifrostCountTokensResponse{ ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.CountTokensRequest, Latency: latency.Milliseconds(), ProviderResponseHeaders: 
providerUtils.ExtractProviderResponseHeaders(resp), }, }, nil } - if schemas.IsAnthropicModel(deployment) { + if schemas.IsAnthropicModel(request.Model) { anthropicResponse := &anthropic.AnthropicCountTokensResponse{} rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)) @@ -3045,12 +2708,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := anthropicResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3073,12 +2730,6 @@ func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key sch } response := vertexResponse.ToBifrostCountTokensResponse(request.Model) - response.ExtraFields.RequestType = schemas.CountTokensRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model - if request.Model != deployment { - response.ExtraFields.ModelDeployment = deployment - } response.ExtraFields.Latency = latency.Milliseconds() response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp) @@ -3143,14 +2794,9 @@ func (provider *VertexProvider) Passthrough( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { - - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - 
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3216,12 +2862,12 @@ func (provider *VertexProvider) Passthrough( tokenSource, err := getAuthTokenSource(key) if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3271,7 +2917,7 @@ func (provider *VertexProvider) Passthrough( body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) } for k := range headers { if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { @@ -3285,9 +2931,6 @@ func (provider *VertexProvider) Passthrough( } bifrostResponse.ExtraFields.ProviderResponseHeaders = headers - bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() - bifrostResponse.ExtraFields.RequestType = schemas.PassthroughRequest - bifrostResponse.ExtraFields.ModelRequested = req.Model bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, 
provider.sendBackRawRequest) { @@ -3303,13 +2946,9 @@ func (provider *VertexProvider) PassthroughStream( key schemas.Key, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { - if key.VertexKeyConfig == nil { - return nil, providerUtils.NewBifrostOperationError("vertex key config is not set", nil, schemas.Vertex) - } - projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue()) if projectID == "" { - return nil, providerUtils.NewConfigurationError("project ID is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("project ID is not set") } keyRegion := key.VertexKeyConfig.Region.GetValue() @@ -3375,13 +3014,13 @@ func (provider *VertexProvider) PassthroughStream( if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err) } token, err := tokenSource.Token() if err != nil { removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue()) providerUtils.ReleaseStreamingResponse(resp) - return nil, providerUtils.NewBifrostOperationError("error getting token", err, schemas.Vertex) + return nil, providerUtils.NewBifrostOperationError("error getting token", err) } fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken) } @@ -3421,9 +3060,9 @@ func (provider *VertexProvider) PassthroughStream( } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err, provider.GetProviderKey()) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, 
provider.GetProviderKey()) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden { @@ -3438,9 +3077,7 @@ func (provider *VertexProvider) PassthroughStream( providerUtils.ReleaseStreamingResponse(resp) return nil, providerUtils.NewBifrostOperationError( "provider returned an empty stream body", - fmt.Errorf("provider returned an empty stream body"), - provider.GetProviderKey(), - ) + fmt.Errorf("provider returned an empty stream body")) } // Set stream idle timeout from provider config. @@ -3453,11 +3090,7 @@ func (provider *VertexProvider) PassthroughStream( // Cancellation must close the raw stream to unblock reads. stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger) - extraFields := schemas.BifrostResponseExtraFields{ - Provider: provider.GetProviderKey(), - ModelRequested: req.Model, - RequestType: schemas.PassthroughStreamRequest, - } + extraFields := schemas.BifrostResponseExtraFields{} statusCode := resp.StatusCode() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -3466,11 +3099,12 @@ func (provider *VertexProvider) PassthroughStream( ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.GetProviderKey(), req.Model, schemas.PassthroughStreamRequest, provider.logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger) } close(ch) }() @@ 
-3519,10 +3153,10 @@ func (provider *VertexProvider) PassthroughStream( } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) extraFields.Latency = time.Since(startTime).Milliseconds() - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, schemas.PassthroughStreamRequest, provider.GetProviderKey(), req.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger) return } } }() return ch, nil -} +} \ No newline at end of file diff --git a/core/providers/vertex/vertex_test.go b/core/providers/vertex/vertex_test.go index 7203bf3080..d754f33d22 100644 --- a/core/providers/vertex/vertex_test.go +++ b/core/providers/vertex/vertex_test.go @@ -27,7 +27,7 @@ func TestVertex(t *testing.T) { testConfig := llmtests.ComprehensiveTestConfig{ Provider: schemas.Vertex, - ChatModel: "google/gemini-2.0-flash-001", + ChatModel: "gemini-2.5-pro", PromptCachingModel: "claude-sonnet-4-5", VisionModel: "claude-sonnet-4-5", TextModel: "", // Vertex doesn't support text completion in newer models @@ -38,12 +38,12 @@ func TestVertex(t *testing.T) { ImageEditModel: "imagen-3.0-capability-001", VideoGenerationModel: "veo-3.1-generate-preview", Scenarios: llmtests.TestScenarios{ - TextCompletion: false, // Not supported - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: false, // Not supported + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, End2EndToolCalling: true, @@ -68,8 +68,10 @@ func TestVertex(t *testing.T) { PromptCaching: true, ListModels: false, CountTokens: true, - StructuredOutputs: true, // Structured outputs with nullable enum support - InterleavedThinking: true, + StructuredOutputs: true, // Structured outputs with nullable enum support + InterleavedThinking: true, + 
EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (GA on Vertex) + ServerToolsViaOpenAIEndpoint: true, // web_search only on Vertex per Table 20 (web_fetch/code_execution skip) }, } diff --git a/core/providers/vllm/utils.go b/core/providers/vllm/utils.go index d2cefce786..ab6d694938 100644 --- a/core/providers/vllm/utils.go +++ b/core/providers/vllm/utils.go @@ -13,9 +13,6 @@ func HandleVLLMResponse[T any](responseBody []byte, response *T, requestBody []b return rawRequest, rawResponse, bifrostErr } if err := sonic.Unmarshal(responseBody, &errorResp); err == nil && errorResp.Error != nil && errorResp.Error.Message != "" { - errorResp.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: schemas.VLLM, - } return rawRequest, rawResponse, &errorResp } return rawRequest, rawResponse, nil diff --git a/core/providers/vllm/vllm.go b/core/providers/vllm/vllm.go index 5cafe918af..548c6a0dc7 100644 --- a/core/providers/vllm/vllm.go +++ b/core/providers/vllm/vllm.go @@ -76,9 +76,7 @@ func (provider *VLLMProvider) baseURLOrError(key schemas.Key) (string, *schemas. if u == "" { return "", providerUtils.NewBifrostOperationError( "no base URL configured: set vllm_key_config.url on the key", - nil, - provider.GetProviderKey(), - ) + nil) } return u, nil } @@ -246,9 +244,6 @@ func (provider *VLLMProvider) Responses(ctx *schemas.BifrostContext, key schemas return nil, err } response := chatResponse.ToBifrostResponsesResponse() - response.ExtraFields.RequestType = schemas.ResponsesRequest - response.ExtraFields.Provider = provider.GetProviderKey() - response.ExtraFields.ModelRequested = request.Model return response, nil } @@ -314,12 +309,14 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( statusCode := resp.StatusCode() if statusCode != fasthttp.StatusOK { - return nil, nil, nil, nil, statusCode, latency, openai.ParseOpenAIError(resp, schemas.RerankRequest, provider.GetProviderKey(), request.Model) + rawErrBody := append([]byte(nil), resp.Body()...) 
+ return nil, nil, nil, rawErrBody, statusCode, latency, openai.ParseOpenAIError(resp) } body, err := providerUtils.CheckAndDecodeBody(resp) if err != nil { - return nil, nil, nil, nil, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err, provider.GetProviderKey()) + rawErrBody := append([]byte(nil), resp.Body()...) + return nil, nil, nil, rawErrBody, statusCode, latency, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err) } sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) @@ -336,16 +333,12 @@ func (provider *VLLMProvider) callVLLMRerankEndpoint( // Rerank performs a rerank request to vLLM's API. func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { - providerName := provider.GetProviderKey() - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { return ToVLLMRerankRequest(request), nil - }, - providerName, - ) + }) if bifrostErr != nil { return nil, bifrostErr } @@ -358,6 +351,9 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke resolvedPath = "/" + resolvedPath } + sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) + sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) + responsePayload, rawRequest, rawResponse, responseBody, statusCode, latency, bifrostErr := provider.callVLLMRerankEndpoint(ctx, key, request, resolvedPath, jsonData) if bifrostErr != nil && !hasPathOverride && isRerankFallbackStatus(statusCode) { var fallbackLatency time.Duration @@ -365,7 +361,7 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke latency += fallbackLatency } if bifrostErr != nil { - return nil, 
providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, responseBody, sendBackRawRequest, sendBackRawResponse) } returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments @@ -373,19 +369,16 @@ func (provider *VLLMProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Ke if err != nil { return nil, providerUtils.EnrichError( ctx, - providerUtils.NewBifrostOperationError("error converting rerank response", err, providerName), + providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonData, responseBody, - provider.sendBackRawRequest, - provider.sendBackRawResponse, + sendBackRawRequest, + sendBackRawResponse, ) } // Keep requested model as the canonical model in Bifrost response. bifrostResponse.Model = request.Model - bifrostResponse.ExtraFields.Provider = providerName - bifrostResponse.ExtraFields.ModelRequested = request.Model - bifrostResponse.ExtraFields.RequestType = schemas.RerankRequest bifrostResponse.ExtraFields.Latency = latency.Milliseconds() if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { @@ -440,7 +433,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Use centralized converter reqBody := openai.ToOpenAITranscriptionRequest(request) if reqBody == nil { - return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil, providerName) + return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil) } reqBody.Stream = schemas.Ptr(true) @@ -496,9 +489,9 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } } if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, 
err, providerName) + return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err) } - return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err, providerName) + return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err) } // Store provider response headers in context before status check so error responses also forward them @@ -507,7 +500,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Check for HTTP errors if resp.StatusCode() != fasthttp.StatusOK { defer providerUtils.ReleaseStreamingResponse(resp) - return nil, openai.ParseOpenAIError(resp, schemas.TranscriptionStreamRequest, providerName, request.Model) + return nil, openai.ParseOpenAIError(resp) } // Large payload streaming passthrough — pipe raw upstream SSE to client @@ -524,11 +517,12 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p // Start streaming in a goroutine go func() { + defer providerUtils.EnsureStreamFinalizerCalled(ctx) defer func() { if ctx.Err() == context.Canceled { - providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger) } else if ctx.Err() == context.DeadlineExceeded { - providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, providerName, request.Model, schemas.TranscriptionStreamRequest, logger) + providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger) } close(responseChan) }() @@ -568,7 +562,7 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) logger.Warn("Error reading stream: %v", readErr) - providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, schemas.TranscriptionStreamRequest, providerName, 
request.Model, logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger) } break } @@ -585,11 +579,6 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p _, _, bifrostErr = HandleVLLMResponse(dataBytes, &response, nil, false, false) if bifrostErr != nil { - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - Provider: providerName, - ModelRequested: request.Model, - RequestType: schemas.TranscriptionStreamRequest, - } ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), dataBytes, false, providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)), responseChan, logger) return @@ -608,11 +597,8 @@ func (provider *VLLMProvider) TranscriptionStream(ctx *schemas.BifrostContext, p } response.ExtraFields = schemas.BifrostResponseExtraFields{ - RequestType: schemas.TranscriptionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - ChunkIndex: chunkIndex, - Latency: time.Since(lastChunkTime).Milliseconds(), + ChunkIndex: chunkIndex, + Latency: time.Since(lastChunkTime).Milliseconds(), } lastChunkTime = time.Now() diff --git a/core/providers/vllm/vllm_test.go b/core/providers/vllm/vllm_test.go index a9d7a1c17d..2f1d5b22c6 100644 --- a/core/providers/vllm/vllm_test.go +++ b/core/providers/vllm/vllm_test.go @@ -37,35 +37,35 @@ func TestVLLM(t *testing.T) { EmbeddingModel: embeddingModel, RerankModel: rerankModel, Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - TextCompletionStream: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + TextCompletionStream: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, 
MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: false, - ImageBase64: false, - MultipleImages: false, - CompleteEnd2End: true, - Embedding: true, - Rerank: rerankModel != "", - ListModels: true, - Reasoning: true, - SpeechSynthesis: false, - SpeechSynthesisStream: false, - Transcription: true, - TranscriptionStream: false, - ImageGeneration: false, - ImageGenerationStream: false, - ImageEdit: false, - ImageEditStream: false, - ImageVariation: false, - ImageVariationStream: false, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: false, + ImageBase64: false, + MultipleImages: false, + CompleteEnd2End: true, + Embedding: true, + Rerank: rerankModel != "", + ListModels: true, + Reasoning: true, + SpeechSynthesis: false, + SpeechSynthesisStream: false, + Transcription: true, + TranscriptionStream: false, + ImageGeneration: false, + ImageGenerationStream: false, + ImageEdit: false, + ImageEditStream: false, + ImageVariation: false, + ImageVariationStream: false, }, } diff --git a/core/providers/xai/errors.go b/core/providers/xai/errors.go index 78b22463e0..38a46888a8 100644 --- a/core/providers/xai/errors.go +++ b/core/providers/xai/errors.go @@ -15,7 +15,7 @@ type XAIErrorResponse struct { // ParseXAIError parses xAI-specific error responses. 
// xAI returns errors in format: {"code": "...", "error": "..."} // Unlike OpenAI which uses: {"error": {"message": "...", "type": "...", "code": "..."}} -func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, providerName schemas.ModelProvider, model string) *schemas.BifrostError { +func ParseXAIError(resp *fasthttp.Response) *schemas.BifrostError { // Try to parse xAI error format var xaiErr XAIErrorResponse bifrostErr := providerUtils.HandleProviderAPIError(resp, &xaiErr) @@ -35,10 +35,5 @@ func ParseXAIError(resp *fasthttp.Response, requestType schemas.RequestType, pro } } - // Set ExtraFields individually to preserve RawResponse from HandleProviderAPIError - bifrostErr.ExtraFields.Provider = providerName - bifrostErr.ExtraFields.ModelRequested = model - bifrostErr.ExtraFields.RequestType = requestType - return bifrostErr } diff --git a/core/providers/xai/xai.go b/core/providers/xai/xai.go index 6fa52d39f3..118b8589bf 100644 --- a/core/providers/xai/xai.go +++ b/core/providers/xai/xai.go @@ -65,7 +65,7 @@ func (provider *XAIProvider) GetProviderKey() schemas.ModelProvider { // ListModels performs a list models request to xAI's API. 
func (provider *XAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if provider.networkConfig.BaseURL == "" { - return nil, providerUtils.NewConfigurationError("base_url is not set", provider.GetProviderKey()) + return nil, providerUtils.NewConfigurationError("base_url is not set") } return openai.HandleOpenAIListModelsRequest( ctx, diff --git a/core/providers/xai/xai_test.go b/core/providers/xai/xai_test.go index 81c479bb7f..d7e0447d28 100644 --- a/core/providers/xai/xai_test.go +++ b/core/providers/xai/xai_test.go @@ -30,29 +30,29 @@ func TestXAI(t *testing.T) { TextModel: "grok-3", VisionModel: "grok-4-1-fast-reasoning", EmbeddingModel: "", // XAI doesn't support embedding - ImageGenerationModel: "grok-2-image", + ImageGenerationModel: "grok-imagine-image", Scenarios: llmtests.TestScenarios{ - TextCompletion: true, - SimpleChat: true, - CompletionStream: true, - MultiTurnConversation: true, - ToolCalls: true, - ToolCallsStreaming: true, + TextCompletion: true, + SimpleChat: true, + CompletionStream: true, + MultiTurnConversation: true, + ToolCalls: true, + ToolCallsStreaming: true, MultipleToolCalls: true, MultipleToolCallsStreaming: true, - End2EndToolCalling: true, - AutomaticFunctionCall: true, - ImageURL: true, - ImageBase64: true, - ImageGeneration: true, - ImageGenerationStream: false, - FileBase64: false, - FileURL: false, - MultipleImages: true, - CompleteEnd2End: true, - Reasoning: true, - Embedding: false, - ListModels: true, + End2EndToolCalling: true, + AutomaticFunctionCall: true, + ImageURL: true, + ImageBase64: true, + ImageGeneration: true, + ImageGenerationStream: false, + FileBase64: false, + FileURL: false, + MultipleImages: true, + CompleteEnd2End: true, + Reasoning: true, + Embedding: false, + ListModels: true, }, } diff --git a/core/schemas/account.go b/core/schemas/account.go index ceaeb2de8a..3dfb16a0b4 100644 --- 
a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,7 +1,12 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas -import "context" +import ( + "context" + "fmt" + "slices" + "strings" +) type KeyStatusType string @@ -10,26 +15,172 @@ const ( KeyStatusListModelsFailed KeyStatusType = "list_models_failed" ) +// WhiteList is a list of values that are allowed to be used. +// Semantics: +// - "*" (alone) means all values are allowed. +// - Empty list means nothing is allowed. +// - Non-empty list (without "*") means only the listed values are allowed. +// +// This type is used generically for any field that needs whitelist behavior +// (e.g., allowed models, allowed tools). +type WhiteList []string + +// Contains reports whether value is in the whitelist. +// The comparison is case-insensitive. +func (wl WhiteList) Contains(value string) bool { + return slices.ContainsFunc(wl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsAllowed reports whether value is allowed by the whitelist. +// Returns true if the whitelist is unrestricted ("*") or value is in the list. +func (wl WhiteList) IsAllowed(value string) bool { + return wl.IsUnrestricted() || wl.Contains(value) +} + +// IsEmpty reports whether the whitelist has no entries. +func (wl WhiteList) IsEmpty() bool { + return len(wl) == 0 +} + +// IsUnrestricted reports whether the whitelist contains only "*", +// meaning all values are allowed. +func (wl WhiteList) IsUnrestricted() bool { + return len(wl) == 1 && wl[0] == "*" +} + +// IsRestricted reports whether the whitelist is not the unrestricted wildcard, +// meaning only the listed values (if any) are allowed. +func (wl WhiteList) IsRestricted() bool { + return !wl.IsUnrestricted() +} + +// Validate checks that the whitelist is well-formed. +// Returns an error if "*" is present alongside other values, or if there are duplicate entries.
+func (wl WhiteList) Validate() error { + if wl.Contains("*") && len(wl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the whitelist") + } + seen := make(map[string]struct{}, len(wl)) + for _, v := range wl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in whitelist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + +// BlackList is a list of values that are denied. +// Semantics: +// - "*" (alone) means all values are blocked. +// - Empty list means nothing is blocked. +// - Non-empty list (without "*") means only the listed values are blocked. +type BlackList []string + +func (bl BlackList) Contains(value string) bool { + return slices.ContainsFunc(bl, func(s string) bool { + return strings.EqualFold(s, value) + }) +} + +// IsBlocked reports whether value is blocked. +func (bl BlackList) IsBlocked(value string) bool { + return bl.IsBlockAll() || bl.Contains(value) +} + +// IsEmpty reports whether the blacklist has no entries (nothing is blocked). +func (bl BlackList) IsEmpty() bool { + return len(bl) == 0 +} + +// IsBlockAll reports whether the blacklist contains only "*", meaning all values are blocked. +func (bl BlackList) IsBlockAll() bool { + return len(bl) == 1 && bl[0] == "*" +} + +// Validate checks that the blacklist is well-formed. +func (bl BlackList) Validate() error { + if bl.Contains("*") && len(bl) > 1 { + return fmt.Errorf("wildcard '*' cannot be used with other values in the blacklist") + } + seen := make(map[string]struct{}, len(bl)) + for _, v := range bl { + normalized := strings.ToLower(v) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate value '%s' in blacklist", v) + } + seen[normalized] = struct{}{} + } + return nil +} + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing.
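The whitelist semantics introduced above are easiest to see in a standalone sketch. `WhiteList` and its methods are re-declared here (copied from the patch) so the snippet compiles on its own; the model names are illustrative only:

```go
package main

import (
	"fmt"
	"slices"
	"strings"
)

// WhiteList mirrors the type added in core/schemas/account.go,
// re-declared here so the sketch is self-contained.
type WhiteList []string

// Contains does a case-insensitive membership check.
func (wl WhiteList) Contains(value string) bool {
	return slices.ContainsFunc(wl, func(s string) bool {
		return strings.EqualFold(s, value)
	})
}

// IsUnrestricted is true only for the single-entry wildcard list.
func (wl WhiteList) IsUnrestricted() bool {
	return len(wl) == 1 && wl[0] == "*"
}

// IsAllowed combines the two: wildcard allows everything, otherwise
// only listed values pass.
func (wl WhiteList) IsAllowed(value string) bool {
	return wl.IsUnrestricted() || wl.Contains(value)
}

func main() {
	all := WhiteList{"*"}
	none := WhiteList{}
	some := WhiteList{"gpt-4o", "Claude-Sonnet-4-5"}

	fmt.Println(all.IsAllowed("anything"))           // true: "*" alone allows everything
	fmt.Println(none.IsAllowed("gpt-4o"))            // false: empty list allows nothing
	fmt.Println(some.IsAllowed("claude-sonnet-4-5")) // true: matching is case-insensitive
	fmt.Println(some.IsAllowed("gemini-2.5-pro"))    // false: not in the list
}
```

`BlackList.IsBlocked` is the mirror image: the wildcard blocks everything, an empty list blocks nothing, and listed values are blocked case-insensitively.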
type Key struct { - ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) - Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) - Value EnvVar `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access - BlacklistedModels []string `json:"blacklisted_models"` // List of models this key cannot access - Weight float64 `json:"weight"` // Weight for load balancing between multiple keys - AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration - VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration - BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration - HuggingFaceKeyConfig *HuggingFaceKeyConfig `json:"huggingface_key_config,omitempty"` // Hugging Face-specific key configuration - ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration - VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration - Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) - UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) - ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection - Status KeyStatusType `json:"status,omitempty"` // Status of key - Description string `json:"description,omitempty"` // Description of key + ID string `json:"id"` // The unique identifier for the key (used by bifrost to identify the key) + Name string `json:"name"` // The name of the key (used by users to identify the key, not used by bifrost) + Value EnvVar `json:"value"` // The actual API key value + 
Models WhiteList `json:"models"` // List of models this key can access + BlacklistedModels BlackList `json:"blacklisted_models"` // List of models this key cannot access + Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + Aliases KeyAliases `json:"aliases,omitempty"` // Mapping of model identifiers to inference profiles + AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration + VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration + BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration + VLLMKeyConfig *VLLMKeyConfig `json:"vllm_key_config,omitempty"` // vLLM-specific key configuration + ReplicateKeyConfig *ReplicateKeyConfig `json:"replicate_key_config,omitempty"` // Replicate-specific key configuration + OllamaKeyConfig *OllamaKeyConfig `json:"ollama_key_config,omitempty"` // Ollama-specific key configuration + SGLKeyConfig *SGLKeyConfig `json:"sgl_key_config,omitempty"` // SGLang-specific key configuration + Enabled *bool `json:"enabled,omitempty"` // Whether the key is active (default:true) + UseForBatchAPI *bool `json:"use_for_batch_api,omitempty"` // Whether this key can be used for batch API operations (default:false for new keys, migrated keys default to true) + ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection + Status KeyStatusType `json:"status,omitempty"` // Status of key + Description string `json:"description,omitempty"` // Description of key +} + +type KeyAliases map[string]string + +func (ka KeyAliases) Validate() error { + seen := make(map[string]struct{}, len(ka)) + for from, to := range ka { + if strings.TrimSpace(from) == "" { + return fmt.Errorf("alias source cannot be empty") + } + if strings.TrimSpace(to) == "" { + return fmt.Errorf("alias target for %q cannot be empty", from) + } + if 
strings.TrimSpace(from) != from { + return fmt.Errorf("alias source %q cannot have leading or trailing whitespace", from) + } + if strings.TrimSpace(to) != to { + return fmt.Errorf("alias target for %q cannot have leading or trailing whitespace", from) + } + normalized := strings.ToLower(from) + if _, ok := seen[normalized]; ok { + return fmt.Errorf("duplicate alias source %q (case-insensitive)", from) + } + seen[normalized] = struct{}{} + } + return nil +} + +func (ka KeyAliases) Resolve(model string) string { + if ka == nil { + return model + } + if alias, ok := ka[model]; ok { + return alias + } + // Fall back to case-insensitive lookup for consistency with WhiteList.Contains + for k, v := range ka { + if strings.EqualFold(k, model) { + return v + } + } + return model } type AzureAuthType string @@ -42,9 +193,8 @@ const ( // AzureKeyConfig represents the Azure-specific configuration. // It contains Azure-specific settings required for service access and deployment management. type AzureKeyConfig struct { - Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" + Endpoint EnvVar `json:"endpoint"` // Azure service endpoint URL + APIVersion *EnvVar `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-10-21" ClientID *EnvVar `json:"client_id,omitempty"` // Azure client ID for authentication ClientSecret *EnvVar `json:"client_secret,omitempty"` // Azure client secret for authentication @@ -55,11 +205,10 @@ type AzureKeyConfig struct { // VertexKeyConfig represents the Vertex-specific configuration. // It contains Vertex-specific settings required for authentication and service access. 
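The `KeyAliases` lookup order (exact match first, then a case-insensitive scan, otherwise pass-through) can be sketched standalone. `KeyAliases` and `Resolve` are re-declared from the patch; the `profile-sonnet` target is a made-up placeholder, not a real inference profile:

```go
package main

import (
	"fmt"
	"strings"
)

// KeyAliases mirrors the type added in core/schemas/account.go.
type KeyAliases map[string]string

// Resolve maps a requested model to its alias target, preferring an
// exact key hit, then falling back to a case-insensitive scan, and
// finally returning the model unchanged if no alias matches.
func (ka KeyAliases) Resolve(model string) string {
	if ka == nil {
		return model
	}
	if alias, ok := ka[model]; ok {
		return alias
	}
	for k, v := range ka {
		if strings.EqualFold(k, model) {
			return v
		}
	}
	return model
}

func main() {
	aliases := KeyAliases{"claude-sonnet-4-5": "profile-sonnet"}

	fmt.Println(aliases.Resolve("claude-sonnet-4-5")) // exact hit -> "profile-sonnet"
	fmt.Println(aliases.Resolve("Claude-Sonnet-4-5")) // case-insensitive fallback -> "profile-sonnet"
	fmt.Println(aliases.Resolve("gemini-2.5-pro"))    // no entry -> "gemini-2.5-pro"
}
```

This is also why `Validate` rejects alias sources that collide case-insensitively: with map iteration order unspecified, two sources differing only in case would make the fallback scan nondeterministic.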
type VertexKeyConfig struct { - ProjectID EnvVar `json:"project_id"` - ProjectNumber EnvVar `json:"project_number"` - Region EnvVar `json:"region"` - AuthCredentials EnvVar `json:"auth_credentials"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles + ProjectID EnvVar `json:"project_id"` + ProjectNumber EnvVar `json:"project_number"` + Region EnvVar `json:"region"` + AuthCredentials EnvVar `json:"auth_credentials"` } // NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string. @@ -90,21 +239,12 @@ type BedrockKeyConfig struct { ExternalID *EnvVar `json:"external_id,omitempty"` RoleSessionName *EnvVar `json:"session_name,omitempty"` - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles - BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations + BatchS3Config *BatchS3Config `json:"batch_s3_config,omitempty"` // S3 bucket configuration for batch operations } // NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings. // To use Bedrock API Key authentication, set Value in Key struct instead. -type HuggingFaceKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - -type ReplicateKeyConfig struct { - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to deployment names -} - // VLLMKeyConfig represents the vLLM-specific key configuration. // It allows each key to target a different vLLM server URL and model name, // enabling per-key routing and round-robin load balancing across multiple vLLM instances. 
@@ -113,6 +253,26 @@ type VLLMKeyConfig struct { ModelName string `json:"model_name"` // Exact model name served on this VLLM instance (used for key selection) } +// ReplicateKeyConfig represents the Replicate-specific key configuration. +// It contains Replicate-specific settings required for authentication and service access. +type ReplicateKeyConfig struct { + UseDeploymentsEndpoint bool `json:"use_deployments_endpoint"` // Whether to use the deployments endpoint instead of the models endpoint +} + +// OllamaKeyConfig represents the Ollama-specific key configuration. +// It allows each key to target a different Ollama server URL, +// enabling per-key routing and round-robin load balancing across multiple Ollama instances. +type OllamaKeyConfig struct { + URL EnvVar `json:"url"` // Ollama server base URL (required, supports env. prefix) +} + +// SGLKeyConfig represents the SGLang-specific key configuration. +// It allows each key to target a different SGLang server URL, +// enabling per-key routing and round-robin load balancing across multiple SGLang instances. +type SGLKeyConfig struct { + URL EnvVar `json:"url"` // SGLang server base URL (required, supports env. prefix) +} + // Account defines the interface for managing provider accounts and their configurations. // It provides methods to access provider-specific settings, API keys, and configurations. type Account interface { diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index c00c5c7078..04cca3915a 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -161,35 +161,47 @@ type BifrostContextKey string // BifrostContextKeyRequestType is a context key for the request type. 
const ( - BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) - BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string - BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) - BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) - BifrostContextKeyRequestID BifrostContextKey = "request-id" // string - BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string - BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost 
governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. 
- BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) - BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) - BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string - BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string + BifrostContextKeySessionToken BifrostContextKey = "bifrost-session-token" // string (session token for authentication - set by auth middleware) + BifrostContextKeyVirtualKey BifrostContextKey = "x-bf-vk" // string + BifrostContextKeyAPIKeyName BifrostContextKey = "x-bf-api-key" // string (explicit key name selection) + BifrostContextKeyAPIKeyID BifrostContextKey = "x-bf-api-key-id" // string (explicit key ID selection, takes priority over name) + BifrostContextKeyRequestID BifrostContextKey = "request-id" // string + BifrostContextKeyFallbackRequestID BifrostContextKey = "fallback-request-id" // string + BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" // Key struct + + // NOTE: []string is used for both keys, and by default all clients/tools are included (when nil). + // If "*" is present, all clients/tools are included, and [] means no clients/tools are included. + // Request context filtering takes priority over client config - context can override client exclusions. 
+ MCPContextKeyIncludeClients BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering + MCPContextKeyIncludeTools BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) + + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitID BifrostContextKey = "bifrost-governance-business-unit-id" // string (to store 
the business unit ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptName BifrostContextKey = "bifrost-selected-prompt-name" // string (display name of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptVersion BifrostContextKey = "bifrost-selected-prompt-version" // string (numeric version as string, e.g. "3" (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptID BifrostContextKey = "bifrost-selected-prompt-id" // string (id of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. 
+ BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) + BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) + BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string + BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" - BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool - BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool + BifrostContextKeyChangeRequestType BifrostContextKey = "bifrost-change-request-type" // RequestType (set by plugins to trigger request type conversion in core, e.g. text->chat or chat->responses) + BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool (per-request override — read by bifrost.go, never overwritten) + BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool (per-request override — read by bifrost.go, never overwritten) BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) 
BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) @@ -202,32 +214,51 @@ const ( BifrostContextKeyStreamStartTime BifrostContextKey = "bifrost-stream-start-time" // time.Time (start time for streaming TTFT calculation - set by bifrost) BifrostContextKeyTracer BifrostContextKey = "bifrost-tracer" // Tracer (tracer instance for completing deferred spans - set by bifrost) BifrostContextKeyDeferTraceCompletion BifrostContextKey = "bifrost-defer-trace-completion" // bool (signals trace completion should be deferred for streaming - set by streaming handlers) - BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) + BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func([]PluginLogEntry) (callback to complete trace after streaming, receives transport plugin logs - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) + BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) + BifrostContextKeyMCPUserID BifrostContextKey = "bifrost-mcp-user-id" // 
string (per-user OAuth user identifier from X-Bf-User-Id header) + BifrostContextKeyOAuthRedirectURI BifrostContextKey = "bifrost-oauth-redirect-uri" // string (OAuth callback URL, e.g. https://host/api/oauth/callback - set by HTTP middleware) + BifrostContextKeyIsMCPGateway BifrostContextKey = "bifrost-is-mcp-gateway" // bool (true when request is being handled via the MCP gateway path) + BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) + BifrostContextKeyPromptsPluginName BifrostContextKey = "prompts-plugin-name" // string (name of the prompts plugin to use - set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyAvailableProviders BifrostContextKey = "available-providers" // []ModelProvider (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyRawRequestResponseForLogging BifrostContextKey = "bifrost-raw-request-response-for-logging" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStoreRawRequestResponse BifrostContextKey = "bifrost-store-raw-request-response" // bool (per-request override — read by bifrost.go, never overwritten) + BifrostContextKeyCaptureRawRequest BifrostContextKey = "bifrost-capture-raw-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when providers should capture raw request bytes + BifrostContextKeyCaptureRawResponse BifrostContextKey = "bifrost-capture-raw-response" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when 
providers should capture raw response bytes + BifrostContextKeyDropRawRequestFromClient BifrostContextKey = "bifrost-drop-raw-request-from-client" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw request should be stripped from the client-facing response + BifrostContextKeyDropRawResponseFromClient BifrostContextKey = "bifrost-drop-raw-response-from-client" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw response should be stripped from the client-facing response + BifrostContextKeyShouldStoreRawInLogs BifrostContextKey = "bifrost-should-store-raw-in-logs" // bool (set by bifrost - DO NOT SET THIS MANUALLY) — true when raw request/response should be persisted in log records BifrostContextKeyRetryDBFetch BifrostContextKey = "bifrost-retry-db-fetch" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyIsCustomProvider BifrostContextKey = "bifrost-is-custom-provider" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) 
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries + BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) + BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request + BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) + BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string + BifrostContextKeyRealtimeProviderSessionID BifrostContextKey = "bifrost-realtime-provider-session-id" // string + BifrostContextKeyRealtimeSource BifrostContextKey = "bifrost-realtime-source" // string ("ei" or "lm") + BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeySCIMClaims BifrostContextKey = "scim_claims" - BifrostContextKeyUserID BifrostContextKey = "user_id" + BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the 
user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) + BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY)) BifrostContextKeyTargetUserID BifrostContextKey = "target_user_id" BifrostContextKeyIsAzureUserAgent BifrostContextKey = "bifrost-is-azure-user-agent" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an Azure user agent (only used in gateway) BifrostContextKeyVideoOutputRequested BifrostContextKey = "bifrost-video-output-requested" BifrostContextKeyValidateKeys BifrostContextKey = "bifrost-validate-keys" // bool (triggers additional key validation during provider add/update) BifrostContextKeyProviderResponseHeaders BifrostContextKey = "bifrost-provider-response-headers" // map[string]string (set by provider handlers for response header forwarding) + BifrostContextKeyMCPAddedTools BifrostContextKey = "bifrost-mcp-added-tools" // []string (set by bifrost - DO NOT SET THIS MANUALLY)) - list of tools added to the request by MCP; all tools are in the format "clientName-toolName" BifrostContextKeyLargePayloadMode BifrostContextKey = "bifrost-large-payload-mode" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) indicates large payload streaming mode is active BifrostContextKeyLargePayloadReader BifrostContextKey = "bifrost-large-payload-reader" // io.Reader (set by bifrost - DO NOT SET THIS MANUALLY)) upstream reader for large payloads BifrostContextKeyLargePayloadContentLength BifrostContextKey = "bifrost-large-payload-content-length" // int (set by bifrost - DO NOT SET THIS MANUALLY)) content length for large payloads @@ -248,7 +279,13 @@ const ( BifrostContextKeySSEReaderFactory BifrostContextKey = "bifrost-sse-reader-factory" // *providerUtils.SSEReaderFactory (set by enterprise — replaces default bufio.Scanner SSE readers with streaming readers) BifrostContextKeySessionID BifrostContextKey =
"bifrost-session-id" // string session ID for the request (session stickiness) BifrostContextKeySessionTTL BifrostContextKey = "bifrost-session-ttl" // time.Duration session TTL for the request (session stickiness) + BifrostContextKeyMCPExtraHeaders BifrostContextKey = "bifrost-mcp-extra-headers" // map[string][]string (these headers are forwarded to the MCP only during tool execution, and only if they are in the allowlist of the MCP client) BifrostContextKeyMCPLogID BifrostContextKey = "bifrost-mcp-log-id" // string (unique UUID for each MCP tool log entry - set per goroutine by agent executor - DO NOT SET THIS MANUALLY) + BifrostContextKeyCompatConvertTextToChat BifrostContextKey = "bifrost-compat-convert-text-to-chat" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatConvertChatToResponses BifrostContextKey = "bifrost-compat-convert-chat-to-responses" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldDropParams BifrostContextKey = "bifrost-compat-should-drop-params" // bool (per-request override from x-bf-compat header) + BifrostContextKeyCompatShouldConvertParams BifrostContextKey = "bifrost-compat-should-convert-params" // bool (per-request override from x-bf-compat header) + BifrostContextKeyAttemptTrail BifrostContextKey = "bifrost-attempt-trail" // []KeyAttemptRecord (set by bifrost - DO NOT SET THIS MANUALLY) - per-attempt key selection history ) const ( @@ -264,6 +301,18 @@ const ( RoutingEngineLoadbalancing = "loadbalancing" ) +// KeyAttemptRecord captures the outcome of a single request attempt within executeRequestWithRetries. +// One record is appended per attempt regardless of whether the key changed between attempts. +// FailReason is supplementary retry metadata: it is populated only when another retry will be +// attempted (i.e. a non-terminal attempt), and is nil on any terminal attempt — including success, +// non-retryable failure, or a retryable error when no retries remain.
+type KeyAttemptRecord struct { + Attempt int `json:"attempt"` + KeyID string `json:"key_id"` + KeyName string `json:"key_name"` + FailReason *string `json:"fail_reason,omitempty"` +} + // RoutingEngineLogEntry represents a log entry from a routing engine // format: [timestamp] [engine] - message type RoutingEngineLogEntry struct { @@ -272,6 +321,27 @@ type RoutingEngineLogEntry struct { Timestamp int64 // Unix milliseconds } +// PluginLogEntry represents a structured log entry emitted by a plugin via ctx.Log(). +type PluginLogEntry struct { + PluginName string `json:"plugin_name"` + Level LogLevel `json:"level"` + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` // Unix milliseconds +} + +// GroupPluginLogsByName groups a flat slice of plugin log entries by plugin name. +// Returns nil if the input is empty. +func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry { + if len(logs) == 0 { + return nil + } + grouped := make(map[string][]PluginLogEntry, min(len(logs), 4)) + for _, entry := range logs { + grouped[entry.PluginName] = append(grouped[entry.PluginName], entry) + } + return grouped +} + // NOTE: for custom plugin implementation dealing with streaming short circuit, // make sure to mark BifrostContextKeyStreamEndIndicator as true at the end of the stream. 
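The new `GroupPluginLogsByName` helper has two documented edge behaviors worth seeing in action: empty or nil input returns nil (not an empty map), and per-plugin entry order is preserved. A minimal, self-contained sketch (the `PluginLogEntry` and `LogLevel` types are redefined locally here without JSON tags; `LogLevel` is assumed to be a string-backed type):

```go
package main

import "fmt"

// LogLevel is assumed to be a string-backed type, mirroring schemas.LogLevel.
type LogLevel string

// PluginLogEntry mirrors the struct added in this diff (JSON tags omitted).
type PluginLogEntry struct {
	PluginName string
	Level      LogLevel
	Message    string
	Timestamp  int64 // Unix milliseconds
}

// GroupPluginLogsByName groups a flat slice of plugin log entries by plugin
// name, preserving per-plugin insertion order; empty or nil input yields nil.
func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry {
	if len(logs) == 0 {
		return nil
	}
	grouped := make(map[string][]PluginLogEntry, min(len(logs), 4))
	for _, entry := range logs {
		grouped[entry.PluginName] = append(grouped[entry.PluginName], entry)
	}
	return grouped
}

func main() {
	logs := []PluginLogEntry{
		{PluginName: "governance", Level: "info", Message: "key selected", Timestamp: 1},
		{PluginName: "prompts", Level: "warn", Message: "version pinned", Timestamp: 2},
		{PluginName: "governance", Level: "info", Message: "budget ok", Timestamp: 3},
	}
	grouped := GroupPluginLogsByName(logs)
	// Two plugins, two entries under "governance", nil map for nil input.
	fmt.Println(len(grouped), len(grouped["governance"]), GroupPluginLogsByName(nil) == nil)
	// prints: 2 2 true
}
```

Returning nil for empty input (rather than an empty map) matters for callers that gate work on `grouped != nil`, e.g. when deciding whether to attach transport plugin logs to a trace.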
@@ -542,6 +612,10 @@ func (br *BifrostRequest) SetModel(model string) { br.ImageVariationRequest.Model = model case br.VideoGenerationRequest != nil: br.VideoGenerationRequest.Model = model + case br.BatchCreateRequest != nil: + if br.BatchCreateRequest.Model != nil { + br.BatchCreateRequest.Model = &model + } } } @@ -784,6 +858,213 @@ func (r *BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { return &BifrostResponseExtraFields{} } +func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if r == nil { + return + } + resolvedModel := resolvedModelUsed + if resolvedModel == "" { + resolvedModel = originalModelRequested + } + switch { + case r.ListModelsResponse != nil: + r.ListModelsResponse.ExtraFields.RequestType = requestType + r.ListModelsResponse.ExtraFields.Provider = provider + r.ListModelsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ListModelsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TextCompletionResponse != nil: + r.TextCompletionResponse.ExtraFields.RequestType = requestType + r.TextCompletionResponse.ExtraFields.Provider = provider + r.TextCompletionResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TextCompletionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ChatResponse != nil: + r.ChatResponse.ExtraFields.RequestType = requestType + r.ChatResponse.ExtraFields.Provider = provider + r.ChatResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ChatResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesResponse != nil: + r.ResponsesResponse.ExtraFields.RequestType = requestType + r.ResponsesResponse.ExtraFields.Provider = provider + r.ResponsesResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ResponsesStreamResponse != nil:
+ r.ResponsesStreamResponse.ExtraFields.RequestType = requestType + r.ResponsesStreamResponse.ExtraFields.Provider = provider + r.ResponsesStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ResponsesStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.CountTokensResponse != nil: + r.CountTokensResponse.ExtraFields.RequestType = requestType + r.CountTokensResponse.ExtraFields.Provider = provider + r.CountTokensResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.CountTokensResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.EmbeddingResponse != nil: + r.EmbeddingResponse.ExtraFields.RequestType = requestType + r.EmbeddingResponse.ExtraFields.Provider = provider + r.EmbeddingResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.EmbeddingResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.RerankResponse != nil: + r.RerankResponse.ExtraFields.RequestType = requestType + r.RerankResponse.ExtraFields.Provider = provider + r.RerankResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.RerankResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechResponse != nil: + r.SpeechResponse.ExtraFields.RequestType = requestType + r.SpeechResponse.ExtraFields.Provider = provider + r.SpeechResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.SpeechStreamResponse != nil: + r.SpeechStreamResponse.ExtraFields.RequestType = requestType + r.SpeechStreamResponse.ExtraFields.Provider = provider + r.SpeechStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.SpeechStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionResponse != nil: + r.TranscriptionResponse.ExtraFields.RequestType = requestType + r.TranscriptionResponse.ExtraFields.Provider = provider + r.TranscriptionResponse.ExtraFields.OriginalModelRequested 
= originalModelRequested + r.TranscriptionResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.TranscriptionStreamResponse != nil: + r.TranscriptionStreamResponse.ExtraFields.RequestType = requestType + r.TranscriptionStreamResponse.ExtraFields.Provider = provider + r.TranscriptionStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.TranscriptionStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationResponse != nil: + r.ImageGenerationResponse.ExtraFields.RequestType = requestType + r.ImageGenerationResponse.ExtraFields.Provider = provider + r.ImageGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ImageGenerationStreamResponse != nil: + r.ImageGenerationStreamResponse.ExtraFields.RequestType = requestType + r.ImageGenerationStreamResponse.ExtraFields.Provider = provider + r.ImageGenerationStreamResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ImageGenerationStreamResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoGenerationResponse != nil: + r.VideoGenerationResponse.ExtraFields.RequestType = requestType + r.VideoGenerationResponse.ExtraFields.Provider = provider + r.VideoGenerationResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoGenerationResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDownloadResponse != nil: + r.VideoDownloadResponse.ExtraFields.RequestType = requestType + r.VideoDownloadResponse.ExtraFields.Provider = provider + r.VideoDownloadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDownloadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoListResponse != nil: + r.VideoListResponse.ExtraFields.RequestType = requestType + r.VideoListResponse.ExtraFields.Provider = provider + r.VideoListResponse.ExtraFields.OriginalModelRequested = 
originalModelRequested + r.VideoListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.VideoDeleteResponse != nil: + r.VideoDeleteResponse.ExtraFields.RequestType = requestType + r.VideoDeleteResponse.ExtraFields.Provider = provider + r.VideoDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.VideoDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileUploadResponse != nil: + r.FileUploadResponse.ExtraFields.RequestType = requestType + r.FileUploadResponse.ExtraFields.Provider = provider + r.FileUploadResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileUploadResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileListResponse != nil: + r.FileListResponse.ExtraFields.RequestType = requestType + r.FileListResponse.ExtraFields.Provider = provider + r.FileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileRetrieveResponse != nil: + r.FileRetrieveResponse.ExtraFields.RequestType = requestType + r.FileRetrieveResponse.ExtraFields.Provider = provider + r.FileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileDeleteResponse != nil: + r.FileDeleteResponse.ExtraFields.RequestType = requestType + r.FileDeleteResponse.ExtraFields.Provider = provider + r.FileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.FileContentResponse != nil: + r.FileContentResponse.ExtraFields.RequestType = requestType + r.FileContentResponse.ExtraFields.Provider = provider + r.FileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.FileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCreateResponse != nil: + 
r.BatchCreateResponse.ExtraFields.RequestType = requestType + r.BatchCreateResponse.ExtraFields.Provider = provider + r.BatchCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchListResponse != nil: + r.BatchListResponse.ExtraFields.RequestType = requestType + r.BatchListResponse.ExtraFields.Provider = provider + r.BatchListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchRetrieveResponse != nil: + r.BatchRetrieveResponse.ExtraFields.RequestType = requestType + r.BatchRetrieveResponse.ExtraFields.Provider = provider + r.BatchRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchCancelResponse != nil: + r.BatchCancelResponse.ExtraFields.RequestType = requestType + r.BatchCancelResponse.ExtraFields.Provider = provider + r.BatchCancelResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchCancelResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchDeleteResponse != nil: + r.BatchDeleteResponse.ExtraFields.RequestType = requestType + r.BatchDeleteResponse.ExtraFields.Provider = provider + r.BatchDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.BatchResultsResponse != nil: + r.BatchResultsResponse.ExtraFields.RequestType = requestType + r.BatchResultsResponse.ExtraFields.Provider = provider + r.BatchResultsResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.BatchResultsResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerCreateResponse != nil: + r.ContainerCreateResponse.ExtraFields.RequestType = requestType + r.ContainerCreateResponse.ExtraFields.Provider = provider + 
r.ContainerCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerListResponse != nil: + r.ContainerListResponse.ExtraFields.RequestType = requestType + r.ContainerListResponse.ExtraFields.Provider = provider + r.ContainerListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerRetrieveResponse != nil: + r.ContainerRetrieveResponse.ExtraFields.RequestType = requestType + r.ContainerRetrieveResponse.ExtraFields.Provider = provider + r.ContainerRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerDeleteResponse != nil: + r.ContainerDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerDeleteResponse.ExtraFields.Provider = provider + r.ContainerDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileCreateResponse != nil: + r.ContainerFileCreateResponse.ExtraFields.RequestType = requestType + r.ContainerFileCreateResponse.ExtraFields.Provider = provider + r.ContainerFileCreateResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileCreateResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileListResponse != nil: + r.ContainerFileListResponse.ExtraFields.RequestType = requestType + r.ContainerFileListResponse.ExtraFields.Provider = provider + r.ContainerFileListResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileListResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileRetrieveResponse != nil: + r.ContainerFileRetrieveResponse.ExtraFields.RequestType = requestType + 
r.ContainerFileRetrieveResponse.ExtraFields.Provider = provider + r.ContainerFileRetrieveResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileRetrieveResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileContentResponse != nil: + r.ContainerFileContentResponse.ExtraFields.RequestType = requestType + r.ContainerFileContentResponse.ExtraFields.Provider = provider + r.ContainerFileContentResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileContentResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.ContainerFileDeleteResponse != nil: + r.ContainerFileDeleteResponse.ExtraFields.RequestType = requestType + r.ContainerFileDeleteResponse.ExtraFields.Provider = provider + r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel + case r.PassthroughResponse != nil: + r.PassthroughResponse.ExtraFields.RequestType = requestType + r.PassthroughResponse.ExtraFields.Provider = provider + r.PassthroughResponse.ExtraFields.OriginalModelRequested = originalModelRequested + r.PassthroughResponse.ExtraFields.ResolvedModelUsed = resolvedModel + } +} + // BifrostMCPResponse is the response struct for all MCP responses. // only ONE of the following fields should be set: // - ChatMessage @@ -796,18 +1077,19 @@ type BifrostMCPResponse struct { // BifrostResponseExtraFields contains additional fields in a response. type BifrostResponseExtraFields struct { - RequestType RequestType `json:"request_type"` - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - ModelDeployment string `json:"model_deployment,omitempty"` // only present for providers which use model deployments (e.g. 
Azure, Bedrock) - Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` - ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` // the model alias the caller sent in the request + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` // the actual provider API identifier used (equals OriginalModelRequested when no alias mapping exists) + Latency int64 `json:"latency"` // in milliseconds (for streaming responses this will be each chunk latency, and the last chunk latency will be the total latency) + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` + ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` // 
params dropped by the compat plugin based on model catalog + ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) } type BifrostMCPResponseExtraFields struct { @@ -895,6 +1177,20 @@ type BifrostError struct { ExtraFields BifrostErrorExtraFields `json:"extra_fields"` } +func (e *BifrostError) PopulateExtraFields(requestType RequestType, provider ModelProvider, originalModelRequested string, resolvedModelUsed string) { + if e == nil { + return + } + e.ExtraFields.RequestType = requestType + e.ExtraFields.Provider = provider + e.ExtraFields.OriginalModelRequested = originalModelRequested + if resolvedModelUsed != "" { + e.ExtraFields.ResolvedModelUsed = resolvedModelUsed + } else { + e.ExtraFields.ResolvedModelUsed = originalModelRequested + } +} + // StreamControl represents stream control options. type StreamControl struct { LogError *bool `json:"log_error,omitempty"` // Optional: Controls logging of error @@ -968,11 +1264,14 @@ func (e *ErrorField) UnmarshalJSON(data []byte) error { // BifrostErrorExtraFields contains additional fields in an error response. 
type BifrostErrorExtraFields struct { - Provider ModelProvider `json:"provider,omitempty"` - ModelRequested string `json:"model_requested,omitempty"` - RequestType RequestType `json:"request_type,omitempty"` - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` - LiteLLMCompat bool `json:"litellm_compat,omitempty"` - KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + Provider ModelProvider `json:"provider,omitempty"` + OriginalModelRequested string `json:"original_model_requested,omitempty"` + ResolvedModelUsed string `json:"resolved_model_used,omitempty"` + RequestType RequestType `json:"request_type,omitempty"` + RawRequest interface{} `json:"raw_request,omitempty"` + RawResponse interface{} `json:"raw_response,omitempty"` + ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` + DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` + KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` + MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication } diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go index 51971243f0..320fa15ac1 100644 --- a/core/schemas/chatcompletions.go +++ b/core/schemas/chatcompletions.go @@ -2,6 +2,7 @@ package schemas import ( "bytes" + "encoding/json" "fmt" "time" ) @@ -80,7 +81,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -113,7 +115,8 @@ func (cr *BifrostChatResponse) 
ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -149,7 +152,8 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion RequestType: TextCompletionRequest, ChunkIndex: cr.ExtraFields.ChunkIndex, Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, Latency: cr.ExtraFields.Latency, RawResponse: cr.ExtraFields.RawResponse, CacheDebug: cr.ExtraFields.CacheDebug, @@ -166,13 +170,15 @@ func (cr *BifrostChatResponse) ToTextCompletionResponse() *BifrostTextCompletion SystemFingerprint: cr.SystemFingerprint, Usage: cr.Usage, ExtraFields: BifrostResponseExtraFields{ - RequestType: TextCompletionRequest, - ChunkIndex: cr.ExtraFields.ChunkIndex, - Provider: cr.ExtraFields.Provider, - ModelRequested: cr.ExtraFields.ModelRequested, - Latency: cr.ExtraFields.Latency, - RawResponse: cr.ExtraFields.RawResponse, - CacheDebug: cr.ExtraFields.CacheDebug, + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + ResolvedModelUsed: cr.ExtraFields.ResolvedModelUsed, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, }, } } @@ -186,6 +192,7 @@ type ChatParameters struct { MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // Maximum number of tokens to 
generate Metadata *map[string]any `json:"metadata,omitempty"` // Metadata to be returned with the response Modalities []string `json:"modalities,omitempty"` // Modalities to be returned with the response + N *int `json:"n,omitempty"` // Number of chat completions to generate when supported ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` Prediction *ChatPrediction `json:"prediction,omitempty"` // Predicted output content (OpenAI only) PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens @@ -208,6 +215,19 @@ type ChatParameters struct { Verbosity *string `json:"verbosity,omitempty"` // "low" | "medium" | "high" WebSearchOptions *ChatWebSearchOptions `json:"web_search_options,omitempty"` // Web search options (OpenAI only) + // Anthropic-native knobs promoted to the neutral layer. These pass through + // typed to Anthropic-family providers (honored/stripped per ProviderFeatures + // in core/providers/anthropic/types.go). Non-Anthropic providers (OpenAI + // etc.) silently ignore them. 
+ TopK *int `json:"top_k,omitempty"` // Anthropic top_k sampling + Speed *string `json:"speed,omitempty"` // "fast" (Anthropic fast-mode-2026-02-01 beta, Opus 4.6 only) + InferenceGeo *string `json:"inference_geo,omitempty"` // Anthropic inference_geo (Claude API only) + MCPServers []ChatMCPServer `json:"mcp_servers,omitempty"` // Anthropic MCP connector (mcp-client-2025-11-20) + Container *ChatContainer `json:"container,omitempty"` // Anthropic container (string id, or object with skills[] — beta skills-2025-10-02) + CacheControl *CacheControl `json:"cache_control,omitempty"` // Top-level request cache control (Anthropic family) + TaskBudget *ChatTaskBudget `json:"task_budget,omitempty"` // Anthropic output_config.task_budget (task-budgets-2026-03-13 beta) + ContextManagement json.RawMessage `json:"context_management,omitempty"` // Anthropic context_management — complex union, passed as raw JSON to the provider layer + // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` @@ -269,6 +289,7 @@ type ChatReasoning struct { Enabled *bool `json:"enabled,omitempty"` // Explicitly enable or disable reasoning (required by OpenRouter to disable reasoning for some models) Effort *string `json:"effort,omitempty"` // "none" | "minimal" | "low" | "medium" | "high" (any value other than "none" will enable reasoning) MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate for the reasoning output (required for anthropic) + Display *string `json:"display,omitempty"` // Anthropic thinking.display: "summarized" | "omitted" (requires model support for adaptive thinking) } // ChatPrediction represents predicted output content for the model to reference (OpenAI only). 
@@ -313,12 +334,179 @@ const ( ChatToolTypeCustom ChatToolType = "custom" ) +type MCPToolAnnotations struct { + Title string `json:"title,omitempty"` // Human-readable title for the tool + ReadOnlyHint *bool `json:"readOnlyHint,omitempty"` // If true, the tool does not modify its environment + DestructiveHint *bool `json:"destructiveHint,omitempty"` // If true, the tool may perform destructive updates + IdempotentHint *bool `json:"idempotentHint,omitempty"` // If true, repeated calls with same args have no additional effect + OpenWorldHint *bool `json:"openWorldHint,omitempty"` // If true, the tool interacts with external entities +} + // ChatTool represents a tool definition. +// +// Three shapes coexist under this type: +// 1. OpenAI function tool: Type="function", Function non-nil. +// 2. Custom tool: Type="custom", Custom non-nil. +// 3. Anthropic server tool: Type=server-tool version string (e.g. +// "web_search_20260209", "computer_20251124", "mcp_toolset"), Function/Custom +// nil, Name populated at top level, and the variant-specific fields +// (MaxUses, DisplayWidthPx, etc.) populated inline. +// +// JSON shape for (3) matches Anthropic's native tool format directly +// (e.g. {"type":"web_search_20260209","name":"web_search","max_uses":5}). +// +// Custom MarshalJSON/UnmarshalJSON enforce the union invariant: +// - On marshal, fields that don't match Type are cleared on a copy so the +// wire format always carries exactly one variant. Mixed caller state +// (e.g. Type="web_search_20260209" with Function also set) gets +// canonicalized instead of being forwarded ambiguously to providers. +// - On unmarshal, tolerantly accept whatever JSON shape comes in, then +// normalize the decoded struct so downstream code sees a canonical shape. 
type ChatTool struct { - Type ChatToolType `json:"type"` - Function *ChatToolFunction `json:"function,omitempty"` // Function definition - Custom *ChatToolCustom `json:"custom,omitempty"` // Custom tool definition - CacheControl *CacheControl `json:"cache_control,omitempty"` // Cache control for the tool + Type ChatToolType `json:"type"` + Function *ChatToolFunction `json:"function,omitempty"` // Function definition (shape 1) + Custom *ChatToolCustom `json:"custom,omitempty"` // Custom tool definition (shape 2) + CacheControl *CacheControl `json:"cache_control,omitempty"` // Cache control for the tool + Annotations *MCPToolAnnotations `json:"-"` // MCP tool annotations (Bifrost-internal, never forwarded to providers) + + // Anthropic-native tool flags promoted to the neutral layer. All optional; + // ignored by providers that don't support them. Gating per ProviderFeatures + // in core/providers/anthropic/types.go. + DeferLoading *bool `json:"defer_loading,omitempty"` // Anthropic advanced-tool-use: defer loading of tool definition + AllowedCallers []string `json:"allowed_callers,omitempty"` // Anthropic advanced-tool-use: which callers can invoke this tool ("direct", "code_execution_20250825", "code_execution_20260120") + InputExamples []ChatToolInputExample `json:"input_examples,omitempty"` // Anthropic tool-examples-2025-10-29: example inputs for the tool + EagerInputStreaming *bool `json:"eager_input_streaming,omitempty"` // Anthropic fine-grained-tool-streaming-2025-05-14: stream input_json_delta before full args are determined (custom tools only) + + // Anthropic server-tool fields (shape 3). All optional; only populated when + // Type is a server-tool version string. Function tools carry their name + // inside Function.Name — use omitempty here so Name doesn't double-emit. 
+ Name string `json:"name,omitempty"` + + // web_search_* and web_fetch_*: + MaxUses *int `json:"max_uses,omitempty"` + AllowedDomains []string `json:"allowed_domains,omitempty"` + BlockedDomains []string `json:"blocked_domains,omitempty"` + UserLocation *ChatToolUserLocation `json:"user_location,omitempty"` + + // web_fetch_* only: + MaxContentTokens *int `json:"max_content_tokens,omitempty"` + Citations *ChatToolCitationsConfig `json:"citations,omitempty"` + UseCache *bool `json:"use_cache,omitempty"` // web_fetch_20260309+ only + + // computer_*: + DisplayWidthPx *int `json:"display_width_px,omitempty"` + DisplayHeightPx *int `json:"display_height_px,omitempty"` + DisplayNumber *int `json:"display_number,omitempty"` + EnableZoom *bool `json:"enable_zoom,omitempty"` // computer_20251124 only + + // text_editor_20250728+: + MaxCharacters *int `json:"max_characters,omitempty"` + + // mcp_toolset: + MCPServerName string `json:"mcp_server_name,omitempty"` + DefaultConfig *ChatMCPToolsetConfig `json:"default_config,omitempty"` + Configs map[string]*ChatMCPToolsetConfig `json:"configs,omitempty"` +} + +// normalizeShape clears fields that don't belong to the ChatTool's active +// variant, encoding the three-way union invariant: +// +// 1. Type="function": keep Function; nil Custom, server-tool Name, and +// variant metadata (function tools carry their name inside Function.Name). +// 2. Type="custom": keep Custom and top-level Name; nil Function and +// server-tool variant metadata. +// 3. Any other Type: server-tool variant — keep Name and variant fields; +// nil Function and Custom. +// +// Called by both Marshal (strict wire format) and Unmarshal (canonicalize +// after tolerant decode of potentially mixed input). 
+func (t *ChatTool) normalizeShape() { + switch t.Type { + case ChatToolTypeFunction: + t.Custom = nil + t.Name = "" + t.clearServerToolVariantFields() + case ChatToolTypeCustom: + t.Function = nil + t.clearServerToolVariantFields() + default: + t.Function = nil + t.Custom = nil + } +} + +func (t *ChatTool) clearServerToolVariantFields() { + t.MaxUses = nil + t.AllowedDomains = nil + t.BlockedDomains = nil + t.UserLocation = nil + t.MaxContentTokens = nil + t.Citations = nil + t.UseCache = nil + t.DisplayWidthPx = nil + t.DisplayHeightPx = nil + t.DisplayNumber = nil + t.EnableZoom = nil + t.MaxCharacters = nil + t.MCPServerName = "" + t.DefaultConfig = nil + t.Configs = nil +} + +// MarshalJSON enforces the ChatTool union invariant: exactly one variant's +// fields are emitted on the wire, matching Type. A mixed-state tool +// (e.g. Type="web_search_20260209" with Function also populated) would +// otherwise serialize both, and downstream provider converters — which +// dispatch on the top-level Type/Name shape — could misinterpret or +// silently forward the stray fields. +func (t ChatTool) MarshalJSON() ([]byte, error) { + normalized := t + normalized.normalizeShape() + type Alias ChatTool + return MarshalSorted((*Alias)(&normalized)) +} + +// UnmarshalJSON tolerantly decodes whatever JSON shape arrives, then +// canonicalizes the struct via normalizeShape so downstream code sees a +// single-variant result even if the input mixed multiple variants. +// Resets the receiver before decoding so omitted optional fields from a +// prior payload don't survive the new decode; mirrors ChatContainer.UnmarshalJSON. 
+func (t *ChatTool) UnmarshalJSON(data []byte) error { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + *t = ChatTool{} + return nil + } + + type Alias ChatTool + var temp Alias + if err := Unmarshal(data, &temp); err != nil { + return err + } + *t = ChatTool(temp) + t.normalizeShape() + return nil +} + +// ChatToolUserLocation is the neutral user_location for web_search tools. +type ChatToolUserLocation struct { + Type *string `json:"type,omitempty"` // "approximate" + City *string `json:"city,omitempty"` + Region *string `json:"region,omitempty"` + Country *string `json:"country,omitempty"` + Timezone *string `json:"timezone,omitempty"` +} + +// ChatToolCitationsConfig is the request-side citations config on web_fetch +// ({"enabled": true/false}). Distinct from response-side text citations. +type ChatToolCitationsConfig struct { + Enabled *bool `json:"enabled,omitempty"` +} + +// ChatMCPToolsetConfig configures an MCP toolset entry (mcp_toolset tool). +type ChatMCPToolsetConfig struct { + Enabled *bool `json:"enabled,omitempty"` + DeferLoading *bool `json:"defer_loading,omitempty"` } // ChatToolFunction represents a function definition. @@ -543,7 +731,6 @@ type AdditionalPropertiesStruct struct { // MarshalJSON implements custom JSON marshalling for AdditionalPropertiesStruct. // It marshals either AdditionalPropertiesBool or AdditionalPropertiesMap based on which is set. 
func (a AdditionalPropertiesStruct) MarshalJSON() ([]byte, error) { - // if both are set, return an error if a.AdditionalPropertiesBool != nil && a.AdditionalPropertiesMap != nil { return nil, fmt.Errorf("both AdditionalPropertiesBool and AdditionalPropertiesMap are set; only one should be non-nil") @@ -946,6 +1133,103 @@ type CacheControl struct { Scope *string `json:"scope,omitempty"` // "user" | "global" } +// --------------------------------------------------------------------------- +// Neutral mirror types for Anthropic-native knobs promoted onto ChatParameters +// --------------------------------------------------------------------------- +// These live in schemas/ (not provider-specific) so ChatParameters stays +// import-free of provider packages. The anthropic provider reads them in +// ToAnthropicChatRequest and maps them to AnthropicMessageRequest fields. + +// ChatContainerSkill describes one skill attached to a container. +// Origin: Anthropic container.skills[] (beta skills-2025-10-02). +type ChatContainerSkill struct { + SkillID string `json:"skill_id"` + Type string `json:"type"` // "anthropic" | "custom" + Version *string `json:"version,omitempty"` // Optional version pin +} + +// ChatContainerObject is the object form of ChatContainer. +// Both fields are optional — ID alone is a bare container reference; +// adding Skills makes it beta-gated. +type ChatContainerObject struct { + ID *string `json:"id,omitempty"` + Skills []ChatContainerSkill `json:"skills,omitempty"` +} + +// ChatContainer is the union "container" field on a chat request. +// Anthropic's API accepts either a plain string (container id) or an object +// with id + skills[]. Mirrors AnthropicContainer in the provider package. +type ChatContainer struct { + ContainerStr *string + ContainerObject *ChatContainerObject +} + +// MarshalJSON emits the raw string or the object form directly. 
+func (c ChatContainer) MarshalJSON() ([]byte, error) { + if c.ContainerStr != nil && c.ContainerObject != nil { + return nil, fmt.Errorf("both ContainerStr and ContainerObject are set; only one should be non-nil") + } + if c.ContainerStr != nil { + return MarshalSorted(*c.ContainerStr) + } + if c.ContainerObject != nil { + return MarshalSorted(c.ContainerObject) + } + return MarshalSorted(nil) +} + +// UnmarshalJSON accepts either a plain string or the object form. +// Uses the build-tag-aware package-level Unmarshal (sonic on native, stdlib +// json on wasm/tinygo) and clears the inactive union arm on each success so +// repeated decodes into the same value don't leave both arms populated. +// JSON null clears both arms. Follows the ChatToolChoice.UnmarshalJSON pattern. +func (c *ChatContainer) UnmarshalJSON(data []byte) error { + trimmed := bytes.TrimSpace(data) + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + c.ContainerStr = nil + c.ContainerObject = nil + return nil + } + + var s string + if err := Unmarshal(data, &s); err == nil { + c.ContainerStr = &s + c.ContainerObject = nil + return nil + } + var obj ChatContainerObject + if err := Unmarshal(data, &obj); err == nil { + c.ContainerStr = nil + c.ContainerObject = &obj + return nil + } + return fmt.Errorf("container field is neither a string nor an object") +} + +// ChatTaskBudget advises the model of a full-loop token budget. +// Origin: Anthropic output_config.task_budget (beta task-budgets-2026-03-13). +type ChatTaskBudget struct { + Type string `json:"type"` // Always "tokens" + Total int `json:"total"` // Total advisory budget + Remaining *int `json:"remaining,omitempty"` // Optional client-side counter +} + +// ChatToolInputExample is one example input for a tool, shown to the model. +// Origin: Anthropic tool.input_examples (beta tool-examples-2025-10-29). 
+type ChatToolInputExample struct { + Input json.RawMessage `json:"input"` + Description *string `json:"description,omitempty"` +} + +// ChatMCPServer is an MCP server definition attached to a chat request. +// Origin: Anthropic mcp_servers[] (mcp-client-2025-11-20 format). +type ChatMCPServer struct { + Type string `json:"type"` // "url" + URL string `json:"url"` + Name string `json:"name"` + AuthorizationToken *string `json:"authorization_token,omitempty"` +} + // ChatInputImage represents image data in a message. type ChatInputImage struct { URL string `json:"url"` @@ -1210,7 +1494,7 @@ type BifrostLLMUsage struct { CompletionTokens int `json:"completion_tokens,omitempty"` CompletionTokensDetails *ChatCompletionTokensDetails `json:"completion_tokens_details,omitempty"` TotalTokens int `json:"total_tokens"` - Cost *BifrostCost `json:"cost,omitempty"` //Only for the providers which support cost calculation + Cost *BifrostCost `json:"cost,omitempty"` // Only for the providers which support cost calculation } type ChatPromptTokensDetails struct { diff --git a/core/schemas/context.go b/core/schemas/context.go index 1ff4663eae..68ac7c435e 100644 --- a/core/schemas/context.go +++ b/core/schemas/context.go @@ -24,6 +24,29 @@ var reservedKeys = []any{ BifrostContextKeySkipKeySelection, BifrostContextKeyURLPath, BifrostContextKeyDeferTraceCompletion, + BifrostContextKeyAttemptTrail, +} + +// pluginLogStore holds plugin log entries accumulated during request processing. +// It is shared between the root BifrostContext and all scoped contexts derived from it. +// Uses a flat slice (not map) to minimize heap allocations. +type pluginLogStore struct { + mu sync.Mutex + logs []PluginLogEntry +} + +// pluginLogStorePool pools pluginLogStore instances to reduce per-request allocations. 
+var pluginLogStorePool = sync.Pool{ + New: func() any { + return &pluginLogStore{logs: make([]PluginLogEntry, 0, 8)} + }, +} + +// pluginScopePool pools BifrostContext instances used as scoped plugin contexts. +var pluginScopePool = sync.Pool{ + New: func() any { + return &BifrostContext{} + }, } // BifrostContext is a custom context.Context implementation that tracks user-set values. @@ -40,6 +63,11 @@ type BifrostContext struct { userValues map[any]any valuesMu sync.RWMutex blockRestrictedWrites atomic.Bool + + // Plugin scoping fields + pluginScope *string // Non-nil when this is a scoped plugin context + pluginLogs atomic.Pointer[pluginLogStore] // Shared log store; lazily initialized on root, shared by scoped contexts + valueDelegate *BifrostContext // For scoped contexts: delegate Value/SetValue to this root context } // NewBifrostContext creates a new BifrostContext with the given parent context and deadline. @@ -166,8 +194,12 @@ func (bc *BifrostContext) cancel(err error) { } // Deadline returns the deadline for this context. +// For scoped contexts, delegates to the root context. // If both this context and the parent have deadlines, the earlier one is returned. func (bc *BifrostContext) Deadline() (time.Time, bool) { + if bc.valueDelegate != nil { + return bc.valueDelegate.Deadline() + } parentDeadline, parentHasDeadline := bc.parent.Deadline() if !bc.hasDeadline && !parentHasDeadline { @@ -195,16 +227,24 @@ func (bc *BifrostContext) Done() <-chan struct{} { } // Err returns the error explaining why the context was cancelled. +// For scoped contexts, delegates to the root context. // Returns nil if the context has not been cancelled. func (bc *BifrostContext) Err() error { + if bc.valueDelegate != nil { + return bc.valueDelegate.Err() + } bc.errMu.RLock() defer bc.errMu.RUnlock() return bc.err } // Value returns the value associated with the key. -// It first checks the internal userValues map, then delegates to the parent context. 
+// For scoped contexts, delegates to the root context via valueDelegate. +// Otherwise checks the internal userValues map, then delegates to the parent context. func (bc *BifrostContext) Value(key any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.Value(key) + } bc.valuesMu.RLock() if val, ok := bc.userValues[key]; ok { bc.valuesMu.RUnlock() @@ -212,12 +252,21 @@ func (bc *BifrostContext) Value(key any) any { } bc.valuesMu.RUnlock() + if bc.parent == nil { + return nil + } + return bc.parent.Value(key) } // SetValue sets a value in the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. // This is thread-safe and can be called concurrently. func (bc *BifrostContext) SetValue(key, value any) { + if bc.valueDelegate != nil { + bc.valueDelegate.SetValue(key, value) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -232,7 +281,12 @@ func (bc *BifrostContext) SetValue(key, value any) { } // ClearValue clears a value from the internal userValues map. +// For scoped contexts, delegates to the root context via valueDelegate. func (bc *BifrostContext) ClearValue(key any) { + if bc.valueDelegate != nil { + bc.valueDelegate.ClearValue(key) + return + } // Check if the key is a reserved key if bc.blockRestrictedWrites.Load() && slices.Contains(reservedKeys, key) { // we silently drop writes for these reserved keys @@ -245,8 +299,12 @@ func (bc *BifrostContext) ClearValue(key any) { } } -// GetAndSetValue gets a value from the internal userValues map and sets it +// GetAndSetValue gets a value from the internal userValues map and sets it. +// For scoped contexts, delegates to the root context via valueDelegate. 
func (bc *BifrostContext) GetAndSetValue(key any, value any) any { + if bc.valueDelegate != nil { + return bc.valueDelegate.GetAndSetValue(key, value) + } bc.valuesMu.Lock() defer bc.valuesMu.Unlock() // Check if the key is a reserved key @@ -340,3 +398,104 @@ func AppendToContextList[T any](ctx *BifrostContext, key BifrostContextKey, valu } ctx.SetValue(key, append(existingValues, value)) } + +// WithPluginScope returns a lightweight scoped BifrostContext from the pool. +// The scoped context shares the root's pluginLogs store and delegates all +// Value/SetValue operations to the root context. +// Call ReleasePluginScope() when done to return the scoped context to the pool. +func (bc *BifrostContext) WithPluginScope(name *string) *BifrostContext { + // Lazily initialize the plugin log store on the root context (CAS to avoid race) + if bc.pluginLogs.Load() == nil { + newStore := pluginLogStorePool.Get().(*pluginLogStore) + if !bc.pluginLogs.CompareAndSwap(nil, newStore) { + // Another goroutine initialized first — return unused store to pool + pluginLogStorePool.Put(newStore) + } + } + + scoped := pluginScopePool.Get().(*BifrostContext) + scoped.parent = bc.parent + scoped.done = bc.done + scoped.pluginScope = name + scoped.pluginLogs.Store(bc.pluginLogs.Load()) + scoped.valueDelegate = bc + return scoped +} + +// ReleasePluginScope returns a scoped context to the pool. +// Safe no-op if called on a non-scoped context. +// Do not use the scoped context after calling this method. +func (bc *BifrostContext) ReleasePluginScope() { + if bc.valueDelegate == nil { + return // not a scoped context + } + bc.parent = nil + bc.done = nil + bc.pluginScope = nil + bc.pluginLogs.Store(nil) + bc.valueDelegate = nil + pluginScopePool.Put(bc) +} + +// Log appends a structured log entry for the current plugin scope. +// No-op if the context is not scoped to a plugin or has no log store. 
+func (bc *BifrostContext) Log(level LogLevel, msg string) { + store := bc.pluginLogs.Load() + if bc.pluginScope == nil || store == nil { + return + } + store.mu.Lock() + store.logs = append(store.logs, PluginLogEntry{ + PluginName: *bc.pluginScope, + Level: level, + Message: msg, + Timestamp: time.Now().UnixMilli(), + }) + store.mu.Unlock() +} + +// GetPluginLogs returns a deep copy of all accumulated plugin log entries. +// Thread-safe. Returns nil if no logs have been recorded. +func (bc *BifrostContext) GetPluginLogs() []PluginLogEntry { + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + store.mu.Lock() + defer store.mu.Unlock() + if len(store.logs) == 0 { + return nil + } + copied := make([]PluginLogEntry, len(store.logs)) + copy(copied, store.logs) + return copied +} + +// DrainPluginLogs transfers ownership of the plugin log slice to the caller. +// The internal log store is returned to the pool after draining. +// Returns nil if no logs have been recorded. +// This should be called once on the root context after all plugin hooks have completed. 
+func (bc *BifrostContext) DrainPluginLogs() []PluginLogEntry { + if bc.valueDelegate != nil { + return nil // scoped contexts must not drain the shared log store + } + store := bc.pluginLogs.Load() + if store == nil { + return nil + } + bc.pluginLogs.Store(nil) + + store.mu.Lock() + logs := store.logs + // Reset with fresh pre-allocated slice before returning to pool + store.logs = make([]PluginLogEntry, 0, 8) + store.mu.Unlock() + + // Return the store to the pool for reuse + pluginLogStorePool.Put(store) + + if len(logs) == 0 { + return nil + } + return logs +} diff --git a/core/schemas/context_test.go b/core/schemas/context_test.go index 75e52e2061..108da2ced0 100644 --- a/core/schemas/context_test.go +++ b/core/schemas/context_test.go @@ -207,3 +207,125 @@ func TestNewBifrostContext_NilParent(t *testing.T) { t.Errorf("Cancelled context should have Canceled error, got %v", ctx.Err()) } } + +// Plugin logging tests + +func TestPluginLog_NoScopeIsNoop(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.Log(LogLevelInfo, "should be ignored") + logs := ctx.GetPluginLogs() + if logs != nil { + t.Errorf("expected nil logs without plugin scope, got %v", logs) + } +} + +func TestPluginLog_SinglePlugin(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "test-plugin" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "hello") + scoped.Log(LogLevelError, "oops") + scoped.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "test-plugin" || logs[0].Level != LogLevelInfo || logs[0].Message != "hello" { + t.Errorf("unexpected first log: %+v", logs[0]) + } + if logs[1].Level != LogLevelError || logs[1].Message != "oops" { + t.Errorf("unexpected second log: %+v", logs[1]) + } +} + +func TestPluginLog_MultiplePlugins(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + + 
name1 := "plugin-a" + s1 := ctx.WithPluginScope(&name1) + s1.Log(LogLevelDebug, "a-msg") + s1.ReleasePluginScope() + + name2 := "plugin-b" + s2 := ctx.WithPluginScope(&name2) + s2.Log(LogLevelWarn, "b-msg") + s2.ReleasePluginScope() + + logs := ctx.GetPluginLogs() + if len(logs) != 2 { + t.Fatalf("expected 2 logs, got %d", len(logs)) + } + if logs[0].PluginName != "plugin-a" { + t.Errorf("expected plugin-a, got %s", logs[0].PluginName) + } + if logs[1].PluginName != "plugin-b" { + t.Errorf("expected plugin-b, got %s", logs[1].PluginName) + } +} + +func TestPluginLog_DrainTransfersOwnership(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + name := "drain-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "msg1") + scoped.ReleasePluginScope() + + drained := ctx.DrainPluginLogs() + if len(drained) != 1 { + t.Fatalf("expected 1 drained log, got %d", len(drained)) + } + + // After drain, GetPluginLogs should return nil + after := ctx.GetPluginLogs() + if after != nil { + t.Errorf("expected nil after drain, got %v", after) + } + + // Second drain should return nil + second := ctx.DrainPluginLogs() + if second != nil { + t.Errorf("expected nil on second drain, got %v", second) + } +} + +func TestPluginLog_ScopedContextValueDelegation(t *testing.T) { + ctx := NewBifrostContext(context.Background(), NoDeadline) + ctx.SetValue(BifrostContextKeyTraceID, "trace-123") + + name := "delegate-test" + scoped := ctx.WithPluginScope(&name) + + // Scoped should read from root + val := scoped.Value(BifrostContextKeyTraceID) + if val != "trace-123" { + t.Errorf("expected trace-123, got %v", val) + } + + // Scoped should write to root + type testContextKey string + const customKey testContextKey = "custom-key" + scoped.SetValue(customKey, "custom-val") + if ctx.Value(customKey) != "custom-val" { + t.Errorf("SetValue on scoped did not delegate to root") + } + + scoped.ReleasePluginScope() +} + +func TestPluginLog_PoolReuse(t *testing.T) { 
+ ctx := NewBifrostContext(context.Background(), NoDeadline) + + // Create and release multiple scoped contexts to exercise the pool + for i := 0; i < 100; i++ { + name := "pool-test" + scoped := ctx.WithPluginScope(&name) + scoped.Log(LogLevelInfo, "pooled") + scoped.ReleasePluginScope() + } + + logs := ctx.DrainPluginLogs() + if len(logs) != 100 { + t.Errorf("expected 100 logs from pool reuse, got %d", len(logs)) + } +} diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go index 1d8890dd1f..9ca2fb2cdf 100644 --- a/core/schemas/embedding.go +++ b/core/schemas/embedding.go @@ -116,14 +116,16 @@ type EmbeddingParameters struct { type EmbeddingData struct { Index int `json:"index"` Object string `json:"object"` // "embedding" - Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64 or [][]float64 + Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32 } type EmbeddingStruct struct { // Embedding responses preserve provider precision in normalized API output. 
- EmbeddingStr *string - EmbeddingArray []float64 - Embedding2DArray [][]float64 + EmbeddingStr *string + EmbeddingArray []float64 + Embedding2DArray [][]float64 + EmbeddingInt8Array []int8 // for int8 / binary formats + EmbeddingInt32Array []int32 // for uint8 / ubinary formats } func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { @@ -136,6 +138,12 @@ func (be EmbeddingStruct) MarshalJSON() ([]byte, error) { if be.Embedding2DArray != nil { return MarshalSorted(be.Embedding2DArray) } + if be.EmbeddingInt8Array != nil { + return Marshal(be.EmbeddingInt8Array) + } + if be.EmbeddingInt32Array != nil { + return Marshal(be.EmbeddingInt32Array) + } return nil, fmt.Errorf("no embedding found") } @@ -161,5 +169,19 @@ func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error { return nil } - return fmt.Errorf("embedding field is neither a string nor an array of float64 nor a 2D array of float64") + // Try to unmarshal as a direct array of int8 + var int8Content []int8 + if err := Unmarshal(data, &int8Content); err == nil { + be.EmbeddingInt8Array = int8Content + return nil + } + + // Try to unmarshal as a direct array of int32 + var int32Content []int32 + if err := Unmarshal(data, &int32Content); err == nil { + be.EmbeddingInt32Array = int32Content + return nil + } + + return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32") } diff --git a/core/schemas/envvar.go b/core/schemas/envvar.go index 6c5f996f08..c8fe249699 100644 --- a/core/schemas/envvar.go +++ b/core/schemas/envvar.go @@ -117,6 +117,9 @@ func (e *EnvVar) Equals(other *EnvVar) bool { // Redacted returns a new SecretKey with the value redacted. func (e *EnvVar) Redacted() *EnvVar { + if e == nil { + return nil + } if e.Val == "" { return &EnvVar{ Val: "", @@ -144,6 +147,34 @@ func (e *EnvVar) Redacted() *EnvVar { } } +// MarshalJSON serializes the EnvVar to JSON. 
+// SECURITY: When the value was sourced from an environment variable, the resolved +// value is automatically redacted before being serialized. This ensures that secrets +// injected via env vars are never leaked through any JSON API response, regardless +// of whether the surrounding code remembered to call Redacted() explicitly. +// +// Plain (non-env) values are still emitted as-is — callers that want to mask those +// must continue using Redacted() at the field level (this matches the existing +// per-provider redaction logic). +// +// This does NOT affect: +// - GORM persistence (uses the Value() driver method, not JSON) +// - Encryption (operates on the Val field directly) +// - Internal LLM request paths (use GetValue() directly) +func (e EnvVar) MarshalJSON() ([]byte, error) { + type envVarAlias EnvVar + out := envVarAlias(e) + if e.FromEnv { + // Redact the resolved value but keep the env var reference and from_env flag + // so the UI still knows which env var backs this field. + redacted := e.Redacted() + if redacted != nil { + out = envVarAlias(*redacted) + } + } + return sonic.Marshal(out) +} + // UnmarshalJSON unmarshals the value from JSON. func (e *EnvVar) UnmarshalJSON(data []byte) error { // This is always going to be value @@ -259,6 +290,17 @@ func (e *EnvVar) IsFromEnv() bool { return e.FromEnv } +// IsSet returns true if the EnvVar has a resolved value or an environment variable reference. +// This should be used instead of GetValue() != "" when checking whether a field was configured, +// because env var references may have an empty Val before resolution (e.g., when the env var +// is not available in the current environment). +func (e *EnvVar) IsSet() bool { + if e == nil { + return false + } + return e.Val != "" || e.EnvVar != "" +} + // GetValue returns the value. 
func (e *EnvVar) GetValue() string { if e == nil { @@ -298,3 +340,14 @@ func (e *EnvVar) CoerceBool(defaultValue bool) bool { } return val } + +// IsDefined returns true if the EnvVar has a source (static value or env key) +func (e *EnvVar) IsDefined() bool { + if e == nil { + return false + } + if e.IsFromEnv() { + return e.EnvVar != "" + } + return e.Val != "" +} diff --git a/core/schemas/envvar_test.go b/core/schemas/envvar_test.go index 5a451b5058..9b22673ae8 100644 --- a/core/schemas/envvar_test.go +++ b/core/schemas/envvar_test.go @@ -419,3 +419,191 @@ func TestEnvVar_IsRedacted(t *testing.T) { }) } } + +// TestEnvVar_IsSet verifies the semantic difference between GetValue() != "" and IsSet(). +// IsSet() must return true when the EnvVar references an env var (regardless of whether +// that env var has been resolved to a non-empty Val). This is the property that the +// BeforeSave hooks rely on so env var references survive persistence. +func TestEnvVar_IsSet(t *testing.T) { + tests := []struct { + name string + input *EnvVar + expected bool + }{ + { + name: "nil envvar", + input: nil, + expected: false, + }, + { + name: "completely empty", + input: &EnvVar{}, + expected: false, + }, + { + name: "only Val set (plain value)", + input: &EnvVar{Val: "abc"}, + expected: true, + }, + { + name: "only EnvVar reference set (env not resolved on this server)", + input: &EnvVar{EnvVar: "env.MISSING", FromEnv: true}, + expected: true, + }, + { + name: "Val and EnvVar both set (env was resolved)", + input: &EnvVar{Val: "resolved-secret", EnvVar: "env.X", FromEnv: true}, + expected: true, + }, + { + name: "FromEnv true but no reference and no value", + input: &EnvVar{FromEnv: true}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.input.IsSet(); got != tt.expected { + t.Errorf("IsSet() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues verifies that any 
EnvVar marshaled +// to JSON with FromEnv=true is automatically masked, regardless of whether the +// surrounding code remembered to call Redacted() explicitly. This is the defense-in-depth +// guarantee that prevents env-resolved secrets from leaking through unredacted fields. +func TestEnvVar_MarshalJSON_AutoRedactsEnvBackedValues(t *testing.T) { + tests := []struct { + name string + input EnvVar + wantValue string + wantEnvVar string + wantFromEnv bool + }{ + { + name: "env-backed long secret is redacted", + input: EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.OPENAI_API_KEY", FromEnv: true}, + wantValue: "sk-1************************mnop", + wantEnvVar: "env.OPENAI_API_KEY", + wantFromEnv: true, + }, + { + name: "env-backed short secret is fully masked", + input: EnvVar{Val: "12345678", EnvVar: "env.SHORT", FromEnv: true}, + wantValue: "********", + wantEnvVar: "env.SHORT", + wantFromEnv: true, + }, + { + name: "env-backed unresolved on this server keeps empty value", + input: EnvVar{Val: "", EnvVar: "env.MISSING", FromEnv: true}, + wantValue: "", + wantEnvVar: "env.MISSING", + wantFromEnv: true, + }, + { + name: "plain value (not from env) is NOT redacted", + input: EnvVar{Val: "2024-10-21", EnvVar: "", FromEnv: false}, + wantValue: "2024-10-21", + wantEnvVar: "", + wantFromEnv: false, + }, + { + name: "empty plain value passes through", + input: EnvVar{Val: "", EnvVar: "", FromEnv: false}, + wantValue: "", + wantEnvVar: "", + wantFromEnv: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.input) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got struct { + Value string `json:"value"` + EnvVar string `json:"env_var"` + FromEnv bool `json:"from_env"` + } + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal of marshaled output failed: %v", err) + } + if got.Value != tt.wantValue { + t.Errorf("value: got %q, want %q", got.Value, tt.wantValue) + } 
+ if got.EnvVar != tt.wantEnvVar { + t.Errorf("env_var: got %q, want %q", got.EnvVar, tt.wantEnvVar) + } + if got.FromEnv != tt.wantFromEnv { + t.Errorf("from_env: got %v, want %v", got.FromEnv, tt.wantFromEnv) + } + }) + } +} + +// TestEnvVar_MarshalJSON_DoesNotMutateOriginal ensures the auto-redaction in MarshalJSON +// does not mutate the receiver. The inference path calls GetValue() to build the actual +// HTTP request to the LLM provider, so the original Val must remain intact. +func TestEnvVar_MarshalJSON_DoesNotMutateOriginal(t *testing.T) { + original := EnvVar{Val: "real-secret-value", EnvVar: "env.SECRET", FromEnv: true} + if _, err := json.Marshal(original); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + if original.Val != "real-secret-value" { + t.Errorf("MarshalJSON mutated Val: got %q, want %q", original.Val, "real-secret-value") + } + if original.GetValue() != "real-secret-value" { + t.Errorf("GetValue() returns mutated value: got %q", original.GetValue()) + } +} + +// TestEnvVar_MarshalJSON_RoundTripIsRedacted verifies that a marshaled-then-unmarshaled +// env-backed EnvVar is recognized as redacted. The merge logic in provider_keys.go relies +// on this so it can detect "the UI sent back the same redacted value, don't overwrite". 
+func TestEnvVar_MarshalJSON_RoundTripIsRedacted(t *testing.T) { + original := EnvVar{Val: "sk-1234567890abcdefghijklmnop", EnvVar: "env.KEY", FromEnv: true} + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var roundTripped EnvVar + if err := json.Unmarshal(data, &roundTripped); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if !roundTripped.IsRedacted() { + t.Errorf("Round-tripped env-backed value should be IsRedacted, got Val=%q", roundTripped.Val) + } + if roundTripped.EnvVar != "env.KEY" { + t.Errorf("env_var reference lost in round-trip: got %q, want %q", roundTripped.EnvVar, "env.KEY") + } +} + +// TestEnvVar_MarshalJSON_DoesNotAffectGetValue is a critical safety net: marshaling an +// EnvVar to JSON must NOT change what GetValue() returns. The inference path uses +// GetValue() to build outgoing LLM requests; if marshaling were to mutate the value, +// every request after a UI fetch would silently start using the redacted mask as the +// API key. +func TestEnvVar_MarshalJSON_DoesNotAffectGetValue(t *testing.T) { + os.Setenv("MY_REAL_API_KEY", "sk-real-secret-1234567890abcdef") + defer os.Unsetenv("MY_REAL_API_KEY") + + ev := NewEnvVar("env.MY_REAL_API_KEY") + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Fatalf("setup: GetValue() = %q, want resolved env value", ev.GetValue()) + } + + // Marshaling would redact in the JSON output, but must not touch the in-memory Val. 
+ if _, err := json.Marshal(ev); err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + if ev.GetValue() != "sk-real-secret-1234567890abcdef" { + t.Errorf("GetValue() returns mutated value after MarshalJSON: got %q", ev.GetValue()) + } +} diff --git a/core/schemas/images.go b/core/schemas/images.go index d16df42a10..fba3c2c08a 100644 --- a/core/schemas/images.go +++ b/core/schemas/images.go @@ -69,8 +69,24 @@ type BifrostImageGenerationResponse struct { // - Size on ImageGenerationResponseParameters (from request params if not in response) // - Quality (low, medium, high, auto) only func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) { + if r == nil || req == nil { + return + } numInputImages, size, quality := getNumInputImagesSizeAndQualityFromRequest(req) + // Backfill Model from whichever inner request carries it. Some provider APIs + // (notably OpenAI /v1/images/*) omit model in the response body. + if r.Model == "" { + switch { + case req.ImageGenerationRequest != nil: + r.Model = req.ImageGenerationRequest.Model + case req.ImageEditRequest != nil: + r.Model = req.ImageEditRequest.Model + case req.ImageVariationRequest != nil: + r.Model = req.ImageVariationRequest.Model + } + } + // Backfill NumInputImages if numInputImages > 0 { if r.Usage == nil { @@ -96,6 +112,22 @@ func (r *BifrostImageGenerationResponse) BackfillParams(req *BifrostRequest) { } } +// getModelFromRequest extracts the model from any image-related request. +func getModelFromRequest(req *BifrostRequest) string { + if req == nil { + return "" + } + switch { + case req.ImageGenerationRequest != nil: + return req.ImageGenerationRequest.Model + case req.ImageEditRequest != nil: + return req.ImageEditRequest.Model + case req.ImageVariationRequest != nil: + return req.ImageVariationRequest.Model + } + return "" +} + // getNumInputImagesSizeAndQualityFromRequest extracts request params for cost calculation. 
// Quality is only returned when it is one of low, medium, high, auto. func getNumInputImagesSizeAndQualityFromRequest(req *BifrostRequest) (numInputImages int, size string, quality string) { @@ -151,10 +183,12 @@ func normalizeImageQuality(q string) string { } type ImageGenerationResponseParameters struct { - Background string `json:"background,omitempty"` - OutputFormat string `json:"output_format,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` + Background string `json:"background,omitempty"` + OutputFormat string `json:"output_format,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + FinishReasons []*string `json:"finish_reasons,omitempty"` + Seeds []int `json:"seeds,omitempty"` } type ImageData struct { @@ -254,7 +288,7 @@ type ImageInput struct { } type ImageEditParameters struct { - Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", + Type *string `json:"type,omitempty"` // "inpainting", "outpainting", "background_removal", "remove_background", "erase_object", "recolor", "search_replace", "control_sketch", "control_structure", "style_guide", "style_transfer", "upscale_fast", "upscale_creative", "upscale_conservative" Background *string `json:"background,omitempty"` // "transparent", "opaque", "auto" InputFidelity *string `json:"input_fidelity,omitempty"` // "low", "high" Mask []byte `json:"mask,omitempty"` diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index 72b43cbc8e..af87cdc743 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -20,8 +20,25 @@ var ( ErrOAuth2TokenExpired = errors.New("oauth2 token expired") ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") + ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") + ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity 
and mcp server") + ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired") ) +// MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires +// the user to authenticate before tool execution can proceed. +type MCPUserOAuthRequiredError struct { + MCPClientID string `json:"mcp_client_id"` + MCPClientName string `json:"mcp_client_name"` + AuthorizeURL string `json:"authorize_url"` + SessionID string `json:"session_id"` + Message string `json:"message"` +} + +func (e *MCPUserOAuthRequiredError) Error() string { + return e.Message +} + // MCPConfig represents the configuration for MCP integration in Bifrost. // It enables tool auto-discovery and execution from local and external MCP servers. type MCPConfig struct { @@ -46,9 +63,10 @@ type MCPConfig struct { } type MCPToolManagerConfig struct { - ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` - MaxAgentDepth int `json:"max_agent_depth"` - CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + ToolExecutionTimeout time.Duration `json:"tool_execution_timeout"` + MaxAgentDepth int `json:"max_agent_depth"` + CodeModeBindingLevel CodeModeBindingLevel `json:"code_mode_binding_level,omitempty"` // How tools are exposed in VFS: "server" or "tool" + DisableAutoToolInject bool `json:"disable_auto_tool_inject,omitempty"` // When true, MCP tools are not injected into requests by default } const ( @@ -68,41 +86,48 @@ const ( type MCPAuthType string const ( - MCPAuthTypeNone MCPAuthType = "none" // No authentication - MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) - MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication + MCPAuthTypeNone MCPAuthType = "none" // No authentication + MCPAuthTypeHeaders MCPAuthType = "headers" // Header-based authentication (API keys, etc.) 
+ MCPAuthTypeOauth MCPAuthType = "oauth" // OAuth 2.0 authentication (server-level, admin authenticates once) + MCPAuthTypePerUserOauth MCPAuthType = "per_user_oauth" // Per-user OAuth 2.0 authentication (each user authenticates individually) ) // MCPClientConfig defines tool filtering for an MCP client. type MCPClientConfig struct { - ID string `json:"client_id"` // Client ID - Name string `json:"name"` // Client name - IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client - ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) - ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) - StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) - AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) - OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) - State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) - Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) - InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) - ToolsToExecute []string `json:"tools_to_execute,omitempty"` // Include-only list. 
+ ID string `json:"client_id"` // Client ID + Name string `json:"name"` // Client name + IsCodeModeClient bool `json:"is_code_mode_client"` // Whether the client is a code mode client + ConnectionType MCPConnectionType `json:"connection_type"` // How to connect (HTTP, STDIO, SSE, or InProcess) + ConnectionString *EnvVar `json:"connection_string,omitempty"` // HTTP or SSE URL (required for HTTP or SSE connections) + StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty"` // STDIO configuration (required for STDIO connections) + AuthType MCPAuthType `json:"auth_type"` // Authentication type (none, headers, or oauth) + OauthConfigID *string `json:"oauth_config_id,omitempty"` // OAuth config ID (references oauth_configs table) + State string `json:"state,omitempty"` // Connection state (connected, disconnected, error) + Headers map[string]EnvVar `json:"headers,omitempty"` // Headers to send with the request (for headers auth type) + AllowedExtraHeaders WhiteList `json:"allowed_extra_headers,omitempty"` // Allowlist of request-level headers that callers may forward to this MCP server at execution time + InProcessServer *server.MCPServer `json:"-"` // MCP server instance for in-process connections (Go package only) + ToolsToExecute WhiteList `json:"tools_to_execute,omitempty"` // Include-only list. // ToolsToExecute semantics: // - ["*"] => all tools are included // - [] => no tools are included (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => include only the specified tools - ToolsToAutoExecute []string `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. + ToolsToAutoExecute WhiteList `json:"tools_to_auto_execute,omitempty"` // Auto-execute list. 
// ToolsToAutoExecute semantics: // - ["*"] => all tools are auto-executed // - [] => no tools are auto-executed (deny-by-default) // - nil/omitted => treated as [] (no tools) // - ["tool1", "tool2"] => auto-execute only the specified tools // Note: If a tool is in ToolsToAutoExecute but not in ToolsToExecute, it will be skipped. - IsPingAvailable bool `json:"is_ping_available"` // Whether the MCP server supports ping for health checks (default: true). If false, uses listTools for health checks. - ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) - ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) - ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + IsPingAvailable *bool `json:"is_ping_available,omitempty"` // Whether the MCP server supports ping for health checks (nil/true = ping; false = listTools). Defaults to true. + ToolSyncInterval time.Duration `json:"tool_sync_interval,omitempty"` // Per-client override for tool sync interval (0 = use global, negative = disabled) + ToolPricing map[string]float64 `json:"tool_pricing,omitempty"` // Tool pricing for each tool (cost per execution) + ConfigHash string `json:"-"` // Config hash for reconciliation (not serialized) + AllowOnAllVirtualKeys bool `json:"allow_on_all_virtual_keys"` // Whether to allow the MCP client to run on all virtual keys + + // Discovered tools for per-user OAuth clients (persisted so they survive restart) + DiscoveredTools map[string]ChatTool `json:"-"` // Discovered tool schemas keyed by prefixed name + DiscoveredToolNameMapping map[string]string `json:"-"` // Mapping from sanitized tool names to original MCP names } // NewMCPClientConfigFromMap creates a new MCP client config from a map[string]any. 
@@ -147,6 +172,9 @@ func (c *MCPClientConfig) HttpHeaders(ctx context.Context, oauth2Provider OAuth2 for key, value := range c.Headers { headers[key] = value.GetValue() } + case MCPAuthTypePerUserOauth: + // Per-user OAuth: headers are injected per-call in executeToolInternal, not at connection level + return headers, nil case MCPAuthTypeNone: // No headers to add default: @@ -179,9 +207,10 @@ type MCPStdioConfig struct { type MCPConnectionState string const ( - MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use - MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected - MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStateConnected MCPConnectionState = "connected" // Client is connected and ready to use + MCPConnectionStateDisconnected MCPConnectionState = "disconnected" // Client is not connected + MCPConnectionStateError MCPConnectionState = "error" // Client is in an error state, and cannot be used + MCPConnectionStatePendingTools MCPConnectionState = "pending_tools" // Connected but tools not yet populated ) // MCPClientState represents a connected MCP client with its configuration and tools. diff --git a/core/schemas/models.go b/core/schemas/models.go index 5a0e8588c1..32b82bb104 100644 --- a/core/schemas/models.go +++ b/core/schemas/models.go @@ -138,7 +138,7 @@ type Model struct { ID string `json:"id"` CanonicalSlug *string `json:"canonical_slug,omitempty"` Name *string `json:"name,omitempty"` - Deployment *string `json:"deployment,omitempty"` // Name of the actual deployment + Alias *string `json:"alias,omitempty"` // Provider API identifier this model alias maps to (e.g. 
Azure deployment name, Bedrock ARN) Created *int64 `json:"created,omitempty"` ContextLength *int `json:"context_length,omitempty"` MaxInputTokens *int `json:"max_input_tokens,omitempty"` diff --git a/core/schemas/models_test.go b/core/schemas/models_test.go index 3e60fdda76..b9748952bd 100644 --- a/core/schemas/models_test.go +++ b/core/schemas/models_test.go @@ -94,7 +94,7 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { Error: &ErrorField{Message: "unauthorized"}, ExtraFields: BifrostErrorExtraFields{ Provider: "openai", - ModelRequested: "gpt-4", + OriginalModelRequested: "gpt-4", }, } keyStatus := KeyStatus{ @@ -112,6 +112,6 @@ func TestKeyStatusMarshalJSON_PreservesErrorFields(t *testing.T) { // Error fields other than key_statuses should be preserved dataStr := string(data) assert.Contains(t, dataStr, `"unauthorized"`) - assert.Contains(t, dataStr, `"model_requested":"gpt-4"`) + assert.Contains(t, dataStr, `"original_model_requested":"gpt-4"`) assert.Contains(t, dataStr, `"status_code":401`) } diff --git a/core/schemas/mux.go b/core/schemas/mux.go index f899f41739..24943d3fbd 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -1258,6 +1258,10 @@ func (responsesResp *BifrostResponsesResponse) ToBifrostChatResponse() *BifrostC Videos: responsesResp.Videos, } + if responsesResp.ID != nil { + chatResp.ID = *responsesResp.ID + } + // Create Choices from ResponsesResponse if len(responsesResp.Output) > 0 { // Convert ResponsesMessages back to ChatMessages @@ -2013,3 +2017,362 @@ func (cr *BifrostChatResponse) ToBifrostResponsesStreamResponse(state *ChatToRes return responses } + +// ToBifrostChatResponse converts a BifrostResponsesStreamResponse chunk to a BifrostChatResponse (chat.completion.chunk). 
+func (rsr *BifrostResponsesStreamResponse) ToBifrostChatResponse() *BifrostChatResponse { + if rsr == nil { + return nil + } + + extraFields := rsr.ExtraFields + extraFields.RequestType = ChatCompletionStreamRequest + + resp := &BifrostChatResponse{ + Object: "chat.completion.chunk", + ExtraFields: extraFields, + SearchResults: rsr.SearchResults, + Videos: rsr.Videos, + Citations: rsr.Citations, + } + + if rsr.Response != nil { + if rsr.Response.ID != nil { + resp.ID = *rsr.Response.ID + } + resp.Created = rsr.Response.CreatedAt + resp.Model = rsr.Response.Model + } + + switch rsr.Type { + case ResponsesStreamResponseTypeOutputTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Content: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeReasoningSummaryTextDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Reasoning: rsr.Delta, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeRefusalDelta: + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Refusal: rsr.Refusal, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeOutputItemAdded: + if rsr.Item == nil || rsr.Item.Type == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + switch *rsr.Item.Type { + case ResponsesMessageTypeFunctionCall: + if rsr.Item.ResponsesToolMessage == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + funcType := 
"function" + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Type: &funcType, + ID: rsr.Item.ResponsesToolMessage.CallID, + Function: ChatAssistantMessageToolCallFunction{ + Name: rsr.Item.ResponsesToolMessage.Name, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesMessageTypeMessage: + role := "assistant" + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + Role: &role, + }, + }, + }, + } + return resp + + default: + // reasoning, file_search_call, web_search_call, etc. — no chat equivalent, + // actual content arrives via separate delta events. + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + + case ResponsesStreamResponseTypeFunctionCallArgumentsDelta: + if rsr.Delta == nil { + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } + var idx uint16 + if rsr.OutputIndex != nil && *rsr.OutputIndex > 0 { + idx = uint16(*rsr.OutputIndex - 1) + } + + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{ + ToolCalls: []ChatAssistantMessageToolCall{ + { + Index: idx, + Function: ChatAssistantMessageToolCallFunction{ + Arguments: *rsr.Delta, + }, + }, + }, + }, + }, + }, + } + return resp + + case ResponsesStreamResponseTypeCompleted, ResponsesStreamResponseTypeIncomplete: + finishReason := 
string(BifrostFinishReasonStop) + if rsr.Type == ResponsesStreamResponseTypeIncomplete { + finishReason = string(BifrostFinishReasonLength) + } + resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + FinishReason: &finishReason, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + if rsr.Response != nil { + if rsr.Response.Usage != nil { + resp.Usage = rsr.Response.Usage.ToBifrostLLMUsage() + } + // Check for tool_calls finish reason + if rsr.Type == ResponsesStreamResponseTypeCompleted { + for _, output := range rsr.Response.Output { + if output.Type != nil && *output.Type == ResponsesMessageTypeFunctionCall { + finishReason = string(BifrostFinishReasonToolCalls) + resp.Choices[0].FinishReason = &finishReason + break + } + } + } + } + return resp + + default: + // Lifecycle events (created, in_progress, content_part.added/done, output_text.done, + // output_item.done, function_call_arguments.done, etc.) → empty chat chunk with no content. 
+ resp.Choices = []BifrostResponseChoice{ + { + Index: 0, + ChatStreamResponseChoice: &ChatStreamResponseChoice{ + Delta: &ChatStreamResponseChoiceDelta{}, + }, + }, + } + return resp + } +} + +// ============================================================================= +// RESPONSE CONVERSION METHODS +// ============================================================================= + +// ToBifrostTextCompletionResponse converts a BifrostChatResponse to a BifrostTextCompletionResponse +func (cr *BifrostChatResponse) ToBifrostTextCompletionResponse() *BifrostTextCompletionResponse { + if cr == nil { + return nil + } + + if len(cr.Choices) == 0 { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + choice := cr.Choices[0] + + // Handle streaming response choice + if choice.ChatStreamResponseChoice != nil && choice.ChatStreamResponseChoice.Delta != nil { + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: choice.ChatStreamResponseChoice.Delta.Content, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Handle non-streaming response choice + if choice.ChatNonStreamResponseChoice != nil { + msg := choice.ChatNonStreamResponseChoice.Message + var textContent *string + if msg != nil && msg.Content != nil { + if msg.Content.ContentStr != nil { + textContent = msg.Content.ContentStr + } else if len(msg.Content.ContentBlocks) > 0 { + var sb strings.Builder + for _, block := range msg.Content.ContentBlocks { + if block.Text != nil { + sb.WriteString(*block.Text) + } + } + if sb.Len() > 0 { + s := sb.String() + textContent = &s + } + } + } + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Choices: []BifrostResponseChoice{ + { + Index: 0, + TextCompletionResponseChoice: &TextCompletionResponseChoice{ + Text: textContent, + }, + FinishReason: choice.FinishReason, + LogProbs: choice.LogProbs, + }, + }, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } + } + + // Fallback case - return basic response structure + return &BifrostTextCompletionResponse{ + ID: cr.ID, + Model: cr.Model, + Object: "text_completion", + SystemFingerprint: cr.SystemFingerprint, + Usage: cr.Usage, + ExtraFields: BifrostResponseExtraFields{ + RequestType: TextCompletionRequest, + ChunkIndex: cr.ExtraFields.ChunkIndex, + Provider: cr.ExtraFields.Provider, + 
OriginalModelRequested: cr.ExtraFields.OriginalModelRequested, + Latency: cr.ExtraFields.Latency, + RawResponse: cr.ExtraFields.RawResponse, + CacheDebug: cr.ExtraFields.CacheDebug, + ProviderResponseHeaders: cr.ExtraFields.ProviderResponseHeaders, + }, + } +} diff --git a/core/schemas/oauth.go b/core/schemas/oauth.go index 1a953c3cc6..7ff3bb9362 100644 --- a/core/schemas/oauth.go +++ b/core/schemas/oauth.go @@ -7,7 +7,7 @@ import ( // OauthProvider interface defines OAuth operations type OAuth2Provider interface { - // GetAccessToken retrieves the access token for a given oauth_config_id + // GetAccessToken retrieves the access token for a given oauth_config_id (server-level OAuth) GetAccessToken(ctx context.Context, oauthConfigID string) (string, error) // RefreshAccessToken refreshes the access token for a given oauth_config_id @@ -18,6 +18,31 @@ type OAuth2Provider interface { // RevokeToken revokes the OAuth token RevokeToken(ctx context.Context, oauthConfigID string) error + + // Per-user OAuth methods + + // GetUserAccessToken retrieves the access token for a per-user OAuth session. + // If the token is expired, it automatically attempts a refresh. + GetUserAccessToken(ctx context.Context, sessionToken string) (string, error) + + // GetUserAccessTokenByIdentity retrieves the upstream access token for a user + // identified by virtualKeyID, userID, or sessionToken (fallback), for a specific + // MCP client. Tokens looked up by identity persist across sessions. + GetUserAccessTokenByIdentity(ctx context.Context, virtualKeyID, userID, sessionToken, mcpClientID string) (string, error) + + // InitiateUserOAuthFlow creates a per-user OAuth session and returns the authorization URL. + // Returns (flow initiation details, session ID for polling, error). 
+ InitiateUserOAuthFlow(ctx context.Context, oauthConfigID string, mcpClientID string, redirectURI string) (*OAuth2FlowInitiation, string, error) + + // CompleteUserOAuthFlow handles the OAuth callback for a per-user flow. + // Returns the session token that the user should send on subsequent requests. + CompleteUserOAuthFlow(ctx context.Context, state string, code string) (string, error) + + // RefreshUserAccessToken refreshes a per-user OAuth access token. + RefreshUserAccessToken(ctx context.Context, sessionToken string) error + + // RevokeUserToken revokes a per-user OAuth token and marks the session as revoked. + RevokeUserToken(ctx context.Context, sessionToken string) error } // OauthConfig represents OAuth client configuration diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index f9ea18a4b3..5e0d068718 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -313,9 +313,15 @@ type ObservabilityPlugin interface { // // Implementations should: // - Convert the trace to their backend's format - // - Send the trace to the backend (can be async) + // - Send the trace to the backend (can be async, but see retention note below) // - Handle errors gracefully (log and continue) // // The context passed is a fresh background context, not the request context. + // + // Retention: implementations MUST NOT retain the *Trace pointer after Inject + // returns. The caller releases the trace back to a sync.Pool immediately after + // Inject completes, so any background goroutine that still references it will + // race with pool reuse. If a plugin needs to forward the trace asynchronously, + // it must copy the data it needs before returning. 
Inject(ctx context.Context, trace *Trace) error } diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 6fe0615b06..5d28002f9f 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -8,18 +8,18 @@ import ( ) const ( - DefaultMaxRetries = 0 - DefaultRetryBackoffInitial = 500 * time.Millisecond - DefaultRetryBackoffMax = 5 * time.Second + DefaultMaxRetries = 0 + DefaultRetryBackoffInitial = 500 * time.Millisecond + DefaultRetryBackoffMax = 5 * time.Second DefaultRequestTimeoutInSeconds = 30 - DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops - DefaultBufferSize = 5000 - DefaultConcurrency = 1000 - DefaultStreamBufferSize = 256 - DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection - DefaultMaxConnsPerHost = 5000 - MaxConnsPerHostUpperBound = 10000 - DefaultMaxIdleConnsPerHost = 40 + DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops + DefaultBufferSize = 5000 + DefaultConcurrency = 1000 + DefaultStreamBufferSize = 256 + DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection + DefaultMaxConnsPerHost = 5000 + MaxConnsPerHostUpperBound = 10000 + DefaultMaxIdleConnsPerHost = 40 ) // Pre-defined errors for provider operations @@ -52,18 +52,18 @@ const ( // - When marshaling to JSON: a time.Duration is converted to milliseconds type NetworkConfig struct { // BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama) - BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) - ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) - DefaultRequestTimeoutInSeconds int 
`json:"default_request_timeout_in_seconds"` // Default timeout for requests - MaxRetries int `json:"max_retries"` // Maximum number of retries - RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) - RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) - InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections - CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections + BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional) + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional) + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests + MaxRetries int `json:"max_retries"` // Maximum number of retries + RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds) + RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds) + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections + CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s) - MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) - EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) - 
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") + MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000) + EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock) + BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-") } // UnmarshalJSON customizes JSON unmarshaling for NetworkConfig. @@ -409,67 +409,6 @@ type CustomProviderConfig struct { RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock) } -type PricingOverrideMatchType string - -const ( - PricingOverrideMatchExact PricingOverrideMatchType = "exact" - PricingOverrideMatchWildcard PricingOverrideMatchType = "wildcard" - PricingOverrideMatchRegex PricingOverrideMatchType = "regex" -) - -// ProviderPricingOverride contains a partial pricing patch applied at lookup time. -// Any nil field falls back to the base pricing data. 
-type ProviderPricingOverride struct { - ModelPattern string `json:"model_pattern"` - MatchType PricingOverrideMatchType `json:"match_type"` - RequestTypes []RequestType `json:"request_types,omitempty"` - - // Basic token pricing - InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"` - OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"` - - // Additional pricing for media - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` - - // Character-based pricing - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - - // Pricing above 128k tokens - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - - // Pricing above 200k tokens - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` - CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` - - // Cache and batch pricing - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` - InputCostPerTokenBatches *float64 
`json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` - - // Image generation pricing - InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` - OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` - OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` - OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` - OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` - OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` - OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` - OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` - OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` - CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` -} - // IsOperationAllowed checks if a specific operation is allowed for this custom provider func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool { if cpc == nil || cpc.AllowedRequests == nil { @@ -485,14 +424,13 @@ type ProviderConfig struct { NetworkConfig NetworkConfig `json:"network_config"` // Network configuration ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings // Logger instance, can be provided by the user or bifrost default logger is used if not provided 
- Logger Logger `json:"-"` - ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) - SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) - CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` - OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` - PricingOverrides []ProviderPricingOverride `json:"pricing_overrides,omitempty"` + Logger Logger `json:"-"` + ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false) + SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false) + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false) + CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` + OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"` } // OpenAIConfig holds OpenAI-specific provider configuration. 
diff --git a/core/schemas/realtime.go b/core/schemas/realtime.go index e1e20d7bf4..ec4fd6789d 100644 --- a/core/schemas/realtime.go +++ b/core/schemas/realtime.go @@ -19,33 +19,75 @@ const ( // Server-to-client event types (received from the provider, forwarded to client) const ( - RTEventSessionCreated RealtimeEventType = "session.created" - RTEventSessionUpdated RealtimeEventType = "session.updated" - RTEventConversationCreated RealtimeEventType = "conversation.created" - RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" - RTEventConversationItemDone RealtimeEventType = "conversation.item.done" - RTEventResponseCreated RealtimeEventType = "response.created" - RTEventResponseDone RealtimeEventType = "response.done" - RTEventResponseTextDelta RealtimeEventType = "response.text.delta" - RTEventResponseTextDone RealtimeEventType = "response.text.done" - RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" - RTEventResponseAudioDone RealtimeEventType = "response.audio.done" - RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" - RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" - RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" - RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" - RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" - RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" - RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" - RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" - RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" - RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" - RTEventInputAudioBufferCleared RealtimeEventType = 
"input_audio_buffer.cleared" - RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" - RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" - RTEventError RealtimeEventType = "error" + RTEventSessionCreated RealtimeEventType = "session.created" + RTEventSessionUpdated RealtimeEventType = "session.updated" + RTEventConversationCreated RealtimeEventType = "conversation.created" + RTEventConversationItemAdded RealtimeEventType = "conversation.item.added" + RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" + RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved" + RTEventConversationItemDone RealtimeEventType = "conversation.item.done" + RTEventResponseCreated RealtimeEventType = "response.created" + RTEventResponseDone RealtimeEventType = "response.done" + RTEventResponseTextDelta RealtimeEventType = "response.text.delta" + RTEventResponseTextDone RealtimeEventType = "response.text.done" + RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" + RTEventResponseAudioDone RealtimeEventType = "response.audio.done" + RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" + RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" + RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" + RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" + RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" + RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" + RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated" + RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" + RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" + RTEventInputAudioTransFailed RealtimeEventType = 
"conversation.item.input_audio_transcription.failed" + RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" + RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" + RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" + RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" + RTEventError RealtimeEventType = "error" ) +// IsRealtimeConversationItemEventType reports whether the event carries a +// canonical conversation item payload after provider translation. +func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool { + switch eventType { + case RTEventConversationItemCreate, + RTEventConversationItemAdded, + RTEventConversationItemCreated, + RTEventConversationItemRetrieved, + RTEventConversationItemDone: + return true + default: + return false + } +} + +// IsRealtimeUserInputEvent reports whether the event represents a finalized +// user input item in the canonical Bifrost realtime schema. +func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Role == "user" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeToolOutputEvent reports whether the event represents a finalized +// tool output item in the canonical Bifrost realtime schema. +func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Type == "function_call_output" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized +// input-audio transcript in the canonical Bifrost realtime schema. +func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool { + return event != nil && event.Type == RTEventInputAudioTransCompleted +} + // BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events. 
// Provider converters translate between this format and the provider-native protocol. type BifrostRealtimeEvent struct { @@ -58,36 +100,42 @@ type BifrostRealtimeEvent struct { Audio []byte `json:"audio,omitempty"` Error *RealtimeError `json:"error,omitempty"` + // ExtraParams preserves provider-specific top-level event fields that are not + // promoted into the common Bifrost schema. + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` + // RawData preserves the original provider event for pass-through or debugging. RawData json.RawMessage `json:"raw_data,omitempty"` } // RealtimeSession describes session configuration for the Realtime connection. type RealtimeSession struct { - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Instructions string `json:"instructions,omitempty"` - Voice string `json:"voice,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` - TurnDetection json.RawMessage `json:"turn_detection,omitempty"` - InputAudioFormat string `json:"input_audio_format,omitempty"` - OutputAudioType string `json:"output_audio_type,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Voice string `json:"voice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` + TurnDetection json.RawMessage `json:"turn_detection,omitempty"` + InputAudioFormat string `json:"input_audio_format,omitempty"` + OutputAudioType string `json:"output_audio_type,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeItem represents a conversation 
item in the Realtime protocol. type RealtimeItem struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Role string `json:"role,omitempty"` - Status string `json:"status,omitempty"` - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - CallID string `json:"call_id,omitempty"` - Arguments string `json:"arguments,omitempty"` - Output string `json:"output,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Status string `json:"status,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeDelta carries incremental content for streaming events. @@ -103,10 +151,28 @@ type RealtimeDelta struct { // RealtimeError describes an error from the Realtime API. type RealtimeError struct { - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Param string `json:"param,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` +} + +// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint +// shape the client called so providers can preserve versioned behavior. +type RealtimeSessionEndpointType string + +const ( + RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets" + RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions" +) + +// RealtimeSessionRoute describes a provider-registered public route for +// ephemeral-token creation. 
+type RealtimeSessionRoute struct { + Path string + EndpointType RealtimeSessionEndpointType + DefaultProvider ModelProvider } // RealtimeProvider is an optional interface that providers can implement to @@ -116,6 +182,129 @@ type RealtimeProvider interface { SupportsRealtimeAPI() bool RealtimeWebSocketURL(key Key, model string) string RealtimeHeaders(key Key) map[string]string + // SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange. + SupportsRealtimeWebRTC() bool + // ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange. + // The provider owns the HTTP specifics (URL, headers, body format). + // session may be nil if the signaling format doesn't include session config. + ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error) ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error) + // ShouldStartRealtimeTurn reports whether the canonical client-side event + // should start pre-hooks. Providers without an explicit turn-start signal + // return false and rely on finalize-time fallback hooks. + ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool + // RealtimeTurnFinalEvent returns the canonical provider event that completes + // a turn and should trigger post-hooks. + RealtimeTurnFinalEvent() RealtimeEventType + RealtimeWebRTCDataChannelLabel() string + RealtimeWebSocketSubprotocol() string + ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool + ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool +} + +// RealtimeLegacyWebRTCProvider is an optional interface for providers that +// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime). +// Only checked for legacy integration routes via type assertion. 
+// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP
+// but targets the provider's legacy/beta endpoint.
+type RealtimeLegacyWebRTCProvider interface {
+	ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError)
+}
+
+// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from
+// their native wire payloads without coupling handlers to a specific protocol.
+type RealtimeUsageExtractor interface {
+	ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage
+	ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage
+}
+
+// RealtimeSessionProvider is an optional interface for providers that can mint
+// short-lived client secrets for browser/client-side Realtime connections.
+// Checked via type assertion: provider.(RealtimeSessionProvider).
+type RealtimeSessionProvider interface {
+	CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError)
+}
+
+// ParseRealtimeEvent decodes a client/provider realtime event while preserving
+// unknown top-level fields in ExtraParams for provider-specific round-tripping.
+func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) {
+	type realtimeEventAlias struct {
+		Type    RealtimeEventType `json:"type"`
+		EventID string            `json:"event_id,omitempty"`
+		Session *RealtimeSession  `json:"session,omitempty"`
+		Item    *RealtimeItem     `json:"item,omitempty"`
+		Delta   *RealtimeDelta    `json:"delta,omitempty"`
+		Audio   []byte            `json:"audio,omitempty"`
+		Error   *RealtimeError    `json:"error,omitempty"`
+	}
+
+	var alias realtimeEventAlias
+	if err := Unmarshal(raw, &alias); err != nil {
+		return nil, err
+	}
+
+	event := &BifrostRealtimeEvent{
+		Type:    alias.Type,
+		EventID: alias.EventID,
+		Session: alias.Session,
+		Item:    alias.Item,
+		Delta:   alias.Delta,
+		Audio:   alias.Audio,
+		Error:   alias.Error,
+	}
+
+	var root map[string]json.RawMessage
+	if err := Unmarshal(raw, &root); err != nil {
+		return nil, err
+	}
+	savedSession := root["session"]
+	savedItem := root["item"]
+	savedError := root["error"]
+	for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} {
+		delete(root, key)
+	}
+	if len(root) > 0 {
+		event.ExtraParams = root
+	}
+	if event.Session != nil {
+		var sessionRoot map[string]json.RawMessage
+		if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil {
+			for _, key := range []string{
+				"id", "model", "modalities", "instructions", "voice", "temperature",
+				"max_output_tokens", "turn_detection", "input_audio_format", "output_audio_format", "tools",
+			} {
+				delete(sessionRoot, key)
+			}
+			if len(sessionRoot) > 0 {
+				event.Session.ExtraParams = sessionRoot
+			}
+		}
+	}
+	if event.Item != nil {
+		var itemRoot map[string]json.RawMessage
+		if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil {
+			for _, key := range []string{
+				"id", "type", "role", "status", "content", "name", "call_id", "arguments", "output",
+			} {
+				delete(itemRoot, key)
+			}
+			if len(itemRoot) > 0 {
+				event.Item.ExtraParams = itemRoot
+			}
+		}
+	}
+	if event.Error != nil {
+		var errorRoot map[string]json.RawMessage
+		if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil {
+			for _, key := range []string{"type", "code", "message", "param"} {
+				delete(errorRoot, key)
+			}
+			if len(errorRoot) > 0 {
+				event.Error.ExtraParams = errorRoot
+			}
+		}
+	}
+
+	return event, nil
 }
diff --git a/core/schemas/realtime_client_secrets.go b/core/schemas/realtime_client_secrets.go
new file mode 100644
index 0000000000..ae97b573a1
--- /dev/null
+++ b/core/schemas/realtime_client_secrets.go
@@ -0,0 +1,66 @@
+package schemas
+
+import (
+	"bytes"
+	"encoding/json"
+	"strings"
+)
+
+// ParseRealtimeClientSecretBody parses a realtime client-secret request body
+// into a mutable raw JSON map while preserving unknown fields.
+func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) {
+	var root map[string]json.RawMessage
+	if err := Unmarshal(raw, &root); err != nil {
+		return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err)
+	}
+	return root, nil
+}
+
+// ExtractRealtimeClientSecretModel extracts the model from either session.model
+// or the legacy top-level model field.
+func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) {
+	if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) {
+		var session map[string]json.RawMessage
+		if err := Unmarshal(sessionJSON, &session); err != nil {
+			return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err)
+		}
+		if modelJSON, ok := session["model"]; ok {
+			var sessionModel string
+			if err := Unmarshal(modelJSON, &sessionModel); err != nil {
+				return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err)
+			}
+			if strings.TrimSpace(sessionModel) != "" {
+				return strings.TrimSpace(sessionModel), nil
+			}
+		}
+	}
+
+	if modelJSON, ok := root["model"]; ok {
+		var model string
+		if err := Unmarshal(modelJSON, &model); err != nil {
+			return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err)
+		}
+		if strings.TrimSpace(model) != "" {
+			return strings.TrimSpace(model), nil
+		}
+	}
+
+	return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil)
+}
+
+// NewRealtimeClientSecretBodyError builds a standard invalid-request style error
+// for HTTP realtime client-secret request parsing/validation.
+func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError {
+	return &BifrostError{
+		IsBifrostError: false,
+		StatusCode:     Ptr(status),
+		Error: &ErrorField{
+			Type:    Ptr(errorType),
+			Message: message,
+			Error:   err,
+		},
+		ExtraFields: BifrostErrorExtraFields{
+			RequestType: RealtimeRequest,
+		},
+	}
+}
diff --git a/core/schemas/realtime_client_secrets_test.go b/core/schemas/realtime_client_secrets_test.go
new file mode 100644
index 0000000000..dfd8f8b1d3
--- /dev/null
+++ b/core/schemas/realtime_client_secrets_test.go
@@ -0,0 +1,40 @@
+package schemas
+
+import (
+	"encoding/json"
+	"testing"
+)
+
+func TestExtractRealtimeClientSecretModel(t *testing.T) {
+	t.Parallel()
+
+	root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`))
+	if err != nil {
+		t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err)
+	}
+
+	model, err := ExtractRealtimeClientSecretModel(root)
+	if err != nil {
+		t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err)
+	}
+	if model != "openai/gpt-4o-realtime-preview" {
+		t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview")
+	}
+}
+
+func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) {
+	t.Parallel()
+
+	root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`))
+	if err != nil {
+		t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err)
+	}
+
+	model, err := ExtractRealtimeClientSecretModel(root)
+	if err != nil {
+		t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err)
+	}
+	if model != "gpt-4o-realtime-preview" {
+		t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview")
+	}
+}
diff --git a/core/schemas/realtime_test.go b/core/schemas/realtime_test.go
new file mode 100644
index 0000000000..69e9e403c8
--- /dev/null
+++ b/core/schemas/realtime_test.go
@@ -0,0 +1,68 @@
+package schemas
+
+import "testing"
+
+func TestIsRealtimeConversationItemEventType(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name      string
+		eventType RealtimeEventType
+		want      bool
+	}{
+		{name: "create", eventType: RTEventConversationItemCreate, want: true},
+		{name: "added", eventType: RTEventConversationItemAdded, want: true},
+		{name: "created", eventType: RTEventConversationItemCreated, want: true},
+		{name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true},
+		{name: "done", eventType: RTEventConversationItemDone, want: true},
+		{name: "response done", eventType: RTEventResponseDone, want: false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want {
+				t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want)
+			}
+		})
+	}
+}
+
+func TestRealtimeCanonicalEventClassifiers(t *testing.T) {
+	t.Parallel()
+
+	userEvent := &BifrostRealtimeEvent{
+		Type: RTEventConversationItemAdded,
+		Item: &RealtimeItem{
+			Role: "user",
+			Type: "message",
+		},
+	}
+	if !IsRealtimeUserInputEvent(userEvent) {
+		t.Fatal("expected conversation.item.added user event to be classified as realtime user input")
+	}
+	if IsRealtimeToolOutputEvent(userEvent) {
+		t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output")
+	}
+
+	toolEvent := &BifrostRealtimeEvent{
+		Type: RTEventConversationItemRetrieved,
+		Item: &RealtimeItem{
+			Type: "function_call_output",
+		},
+	}
+	if !IsRealtimeToolOutputEvent(toolEvent) {
+		t.Fatal("expected function_call_output item to be classified as realtime tool output")
+	}
+	if IsRealtimeUserInputEvent(toolEvent) {
+		t.Fatal("did not expect function_call_output item to be classified as realtime user input")
+	}
+
+	transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted}
+	if !IsRealtimeInputTranscriptEvent(transcriptEvent) {
+		t.Fatal("expected input audio transcription completion to be classified as transcript event")
+	}
+	if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) {
+		t.Fatal("did not expect input audio transcription delta to be classified as transcript event")
+	}
+}
diff --git a/core/schemas/responses.go b/core/schemas/responses.go
index adc1d5d07f..04a9be8bd2 100644
--- a/core/schemas/responses.go
+++ b/core/schemas/responses.go
@@ -1406,6 +1406,14 @@ type ResponsesTool struct {
 	// Not in OpenAI's schemas, but sent by a few providers (Anthropic, Bedrock are some of them)
 	CacheControl *CacheControl `json:"cache_control,omitempty"`
 
+	// Anthropic-native tool flags promoted to the neutral layer. All optional;
+	// ignored by providers that don't support them. Gated per ProviderFeatures
+	// in core/providers/anthropic/types.go.
+	DeferLoading        *bool                  `json:"defer_loading,omitempty"`         // Anthropic advanced-tool-use: defer loading of tool definition
+	AllowedCallers      []string               `json:"allowed_callers,omitempty"`       // Anthropic advanced-tool-use: which callers can invoke this tool
+	InputExamples       []ChatToolInputExample `json:"input_examples,omitempty"`        // Anthropic tool-examples-2025-10-29: example inputs for the tool
+	EagerInputStreaming *bool                  `json:"eager_input_streaming,omitempty"` // Anthropic fine-grained-tool-streaming-2025-05-14
+
 	*ResponsesToolFunction
 	*ResponsesToolFileSearch
 	*ResponsesToolComputerUsePreview
@@ -1463,6 +1471,38 @@ func (t ResponsesTool) MarshalJSON() ([]byte, error) {
 			return nil, err
 		}
 	}
+	// Anthropic-native tool flags promoted to the neutral layer. Must be
+	// emitted here (before the type-specific merge) so the wire format carries
+	// them to providers that gate features on these keys. Without this block
+	// MarshalJSON silently drops the fields despite their json tags.
+	if t.DeferLoading != nil {
+		if data, err = sjson.SetBytes(data, "defer_loading", *t.DeferLoading); err != nil {
+			return nil, err
+		}
+	}
+	if len(t.AllowedCallers) > 0 {
+		callersBytes, callersErr := MarshalSorted(t.AllowedCallers)
+		if callersErr != nil {
+			return nil, callersErr
+		}
+		if data, err = sjson.SetRawBytes(data, "allowed_callers", callersBytes); err != nil {
+			return nil, err
+		}
+	}
+	if len(t.InputExamples) > 0 {
+		examplesBytes, examplesErr := MarshalSorted(t.InputExamples)
+		if examplesErr != nil {
+			return nil, examplesErr
+		}
+		if data, err = sjson.SetRawBytes(data, "input_examples", examplesBytes); err != nil {
+			return nil, err
+		}
+	}
+	if t.EagerInputStreaming != nil {
+		if data, err = sjson.SetBytes(data, "eager_input_streaming", *t.EagerInputStreaming); err != nil {
+			return nil, err
+		}
+	}
 
 	// Marshal the type-specific embedded struct and merge its fields
 	var typeBytes []byte
@@ -1566,6 +1606,32 @@ func (t *ResponsesTool) UnmarshalJSON(data []byte) error {
 		}
 		t.CacheControl = &cc
 	}
+	// Anthropic-native tool flags. Mirror the emit side in MarshalJSON above —
+	// without these reads, a round-trip silently drops the fields.
+	if v, ok := raw["defer_loading"].(bool); ok {
+		t.DeferLoading = Ptr(v)
+	}
+	if v, ok := raw["allowed_callers"]; ok {
+		bytes, err := MarshalSorted(v)
+		if err != nil {
+			return err
+		}
+		if err := Unmarshal(bytes, &t.AllowedCallers); err != nil {
+			return err
+		}
+	}
+	if v, ok := raw["input_examples"]; ok {
+		bytes, err := MarshalSorted(v)
+		if err != nil {
+			return err
+		}
+		if err := Unmarshal(bytes, &t.InputExamples); err != nil {
+			return err
+		}
+	}
+	if v, ok := raw["eager_input_streaming"].(bool); ok {
+		t.EagerInputStreaming = Ptr(v)
+	}
 
 	// Based on type, unmarshal into the appropriate embedded struct
 	switch t.Type {
diff --git a/core/schemas/serialization_test.go b/core/schemas/serialization_test.go
index 66b720b680..94b9a27dae 100644
--- a/core/schemas/serialization_test.go
+++ b/core/schemas/serialization_test.go
@@ -781,6 +781,204 @@ func TestResponsesTool_MarshalJSON_RoundTrip(t *testing.T) {
 	}
 }
 
+// TestResponsesTool_RoundTrip_AnthropicFields ensures the Anthropic-native tool
+// flags promoted onto ResponsesTool (defer_loading, allowed_callers,
+// input_examples, eager_input_streaming) survive a full Marshal→Unmarshal→
+// Marshal cycle. Before MarshalJSON/UnmarshalJSON were taught to handle these
+// keys, all four were silently dropped at the JSON boundary.
+func TestResponsesTool_RoundTrip_AnthropicFields(t *testing.T) {
+	original := ResponsesTool{
+		Type:                ResponsesToolTypeFunction,
+		Name:                Ptr("lookup"),
+		Description:         Ptr("lookup something"),
+		DeferLoading:        Ptr(true),
+		AllowedCallers:      []string{"direct", "agent"},
+		EagerInputStreaming: Ptr(false),
+		InputExamples: []ChatToolInputExample{
+			{Input: json.RawMessage(`{"q":"hello"}`), Description: Ptr("basic")},
+			{Input: json.RawMessage(`{"q":"world"}`)},
+		},
+		ResponsesToolFunction: &ResponsesToolFunction{
+			Parameters: &ToolFunctionParameters{},
+		},
+	}
+
+	data, err := Marshal(original)
+	require.NoError(t, err)
+
+	// All four keys must appear in the wire bytes.
+	for _, key := range []string{`"defer_loading"`, `"allowed_callers"`, `"input_examples"`, `"eager_input_streaming"`} {
+		assert.Contains(t, string(data), key,
+			"%s must be emitted by MarshalJSON — otherwise it is silently dropped", key)
+	}
+
+	var decoded ResponsesTool
+	require.NoError(t, Unmarshal(data, &decoded))
+
+	require.NotNil(t, decoded.DeferLoading)
+	assert.True(t, *decoded.DeferLoading)
+	assert.Equal(t, []string{"direct", "agent"}, decoded.AllowedCallers)
+	require.NotNil(t, decoded.EagerInputStreaming)
+	assert.False(t, *decoded.EagerInputStreaming)
+	require.Len(t, decoded.InputExamples, 2)
+	assert.JSONEq(t, `{"q":"hello"}`, string(decoded.InputExamples[0].Input))
+	require.NotNil(t, decoded.InputExamples[0].Description)
+	assert.Equal(t, "basic", *decoded.InputExamples[0].Description)
+	assert.JSONEq(t, `{"q":"world"}`, string(decoded.InputExamples[1].Input))
+
+	// Second-round marshal must be byte-stable.
+	data2, err := Marshal(decoded)
+	require.NoError(t, err)
+	assert.Equal(t, string(data), string(data2), "round-trip must be stable")
+}
+
+// TestChatTool_MarshalJSON_EnforcesUnion verifies that the custom codec
+// canonicalizes mixed-state ChatTools on the wire, regardless of what the
+// caller populated in memory. Exactly one variant's fields survive marshal —
+// matching Type — so downstream provider converters can't misinterpret or
+// forward stray fields from a different shape.
+func TestChatTool_MarshalJSON_EnforcesUnion(t *testing.T) {
+	t.Run("function_type_clears_custom_and_server_tool_fields", func(t *testing.T) {
+		tool := ChatTool{
+			Type:     ChatToolTypeFunction,
+			Function: &ChatToolFunction{Name: "get_weather"},
+			// Mixed state: server-tool + custom fields also populated.
+			Custom:         &ChatToolCustom{},
+			Name:           "leaked_name",
+			MaxUses:        Ptr(5),
+			DisplayWidthPx: Ptr(1280),
+			MCPServerName:  "leaked_server",
+		}
+		data, err := Marshal(tool)
+		require.NoError(t, err)
+		raw := string(data)
+
+		assert.Contains(t, raw, `"type":"function"`)
+		assert.Contains(t, raw, `"get_weather"`)
+		for _, leak := range []string{`"custom"`, `"leaked_name"`, `"max_uses"`, `"display_width_px"`, `"mcp_server_name"`} {
+			assert.NotContains(t, raw, leak, "function-type wire must not carry %s", leak)
+		}
+	})
+
+	t.Run("custom_type_clears_function_and_server_tool_fields", func(t *testing.T) {
+		tool := ChatTool{
+			Type:   ChatToolTypeCustom,
+			Custom: &ChatToolCustom{Format: &ChatToolCustomFormat{Type: "text"}},
+			Name:   "my_custom",
+			// Leaks
+			Function: &ChatToolFunction{Name: "should_be_stripped"},
+			MaxUses:  Ptr(5),
+		}
+		data, err := Marshal(tool)
+		require.NoError(t, err)
+		raw := string(data)
+
+		assert.Contains(t, raw, `"type":"custom"`)
+		assert.Contains(t, raw, `"my_custom"`) // custom tool retains top-level Name
+		assert.Contains(t, raw, `"format"`)    // custom's format field
+		assert.NotContains(t, raw, `"function"`)
+		assert.NotContains(t, raw, `"should_be_stripped"`)
+		assert.NotContains(t, raw, `"max_uses"`)
+	})
+
+	t.Run("server_tool_type_clears_function_and_custom", func(t *testing.T) {
+		tool := ChatTool{
+			Type:           "web_search_20260209",
+			Name:           "web_search",
+			MaxUses:        Ptr(5),
+			AllowedCallers: []string{"direct"},
+			// Leaks
+			Function: &ChatToolFunction{Name: "should_be_stripped"},
+			Custom:   &ChatToolCustom{},
+		}
+		data, err := Marshal(tool)
+		require.NoError(t, err)
+		raw := string(data)
+
+		assert.Contains(t, raw, `"type":"web_search_20260209"`)
+		assert.Contains(t, raw, `"web_search"`)
+		assert.Contains(t, raw, `"max_uses":5`)
+		assert.Contains(t, raw, `"allowed_callers":["direct"]`)
+		assert.NotContains(t, raw, `"function"`)
+		assert.NotContains(t, raw, `"custom"`)
+		assert.NotContains(t, raw, `"should_be_stripped"`)
+	})
+}
+
+// TestChatTool_UnmarshalJSON_NormalizesMixedInput verifies that tolerant
+// decode of a mixed-shape payload produces a canonical single-variant struct
+// so downstream provider conversion code doesn't have to defend against
+// the untrusted shape.
+func TestChatTool_UnmarshalJSON_NormalizesMixedInput(t *testing.T) {
+	t.Run("function_type_mixed_with_server_fields_normalizes", func(t *testing.T) {
+		// Caller sends a function tool but also includes server-tool metadata.
+		raw := []byte(`{
+			"type":"function",
+			"function":{"name":"get_weather"},
+			"name":"stray_server_name",
+			"max_uses":5,
+			"display_width_px":1280
+		}`)
+		var tool ChatTool
+		require.NoError(t, Unmarshal(raw, &tool))
+
+		assert.Equal(t, ChatToolTypeFunction, tool.Type)
+		require.NotNil(t, tool.Function)
+		assert.Equal(t, "get_weather", tool.Function.Name)
+		assert.Empty(t, tool.Name, "function-type must nil top-level Name (lives in Function.Name)")
+		assert.Nil(t, tool.MaxUses)
+		assert.Nil(t, tool.DisplayWidthPx)
+	})
+
+	t.Run("server_tool_type_mixed_with_function_normalizes", func(t *testing.T) {
+		// Caller sends a server-tool but also includes function.
+		raw := []byte(`{
+			"type":"web_search_20260209",
+			"name":"web_search",
+			"max_uses":5,
+			"function":{"name":"stray"}
+		}`)
+		var tool ChatTool
+		require.NoError(t, Unmarshal(raw, &tool))
+
+		assert.Equal(t, ChatToolType("web_search_20260209"), tool.Type)
+		assert.Equal(t, "web_search", tool.Name)
+		require.NotNil(t, tool.MaxUses)
+		assert.Equal(t, 5, *tool.MaxUses)
+		assert.Nil(t, tool.Function, "server-tool must nil Function")
+		assert.Nil(t, tool.Custom, "server-tool must nil Custom")
+	})
+}
+
+// TestChatTool_RoundTrip_SurvivesMixedInput verifies that a mixed-input
+// payload, once canonicalized by Unmarshal and re-emitted by Marshal, drops
+// the stray fields and produces a deterministic single-variant wire format.
+func TestChatTool_RoundTrip_SurvivesMixedInput(t *testing.T) {
+	raw := []byte(`{
+		"type":"web_search_20260209",
+		"name":"web_search",
+		"max_uses":5,
+		"function":{"name":"stray"},
+		"custom":{"format":{"type":"text"}}
+	}`)
+	var tool ChatTool
+	require.NoError(t, Unmarshal(raw, &tool))
+
+	out, err := Marshal(tool)
+	require.NoError(t, err)
+	outStr := string(out)
+	assert.NotContains(t, outStr, `"function"`)
+	assert.NotContains(t, outStr, `"custom"`)
+	assert.Contains(t, outStr, `"web_search_20260209"`)
+
+	// Second pass must be byte-stable (critical for prompt caching keys).
+	var tool2 ChatTool
+	require.NoError(t, Unmarshal(out, &tool2))
+	out2, err := Marshal(tool2)
+	require.NoError(t, err)
+	assert.Equal(t, string(out), string(out2), "round-trip must be stable")
+}
+
 func TestToolFunctionParameters_ExplicitEmptyObjectPreserved(t *testing.T) {
 	var params ToolFunctionParameters
 	err := Unmarshal([]byte(`{}`), &params)
@@ -1069,3 +1267,90 @@ func TestResponsesTool_UnmarshalJSON_NormalizesVersionedToolTypes(t *testing.T) {
 		})
 	}
 }
+
+// TestSonic_ChatTool_AnnotationsNeverSerialized verifies that MCPToolAnnotations
+// (json:"-") are never included in the JSON payload sent to providers.
+func TestSonic_ChatTool_AnnotationsNeverSerialized(t *testing.T) {
+	readOnly := true
+	destructive := false
+
+	tool := ChatTool{
+		Type: ChatToolTypeFunction,
+		Function: &ChatToolFunction{
+			Name:        "read_file",
+			Description: Ptr("Reads a file from the filesystem"),
+			Parameters: &ToolFunctionParameters{
+				Type:       "object",
+				Properties: NewOrderedMapFromPairs(KV("path", map[string]interface{}{"type": "string"})),
+				Required:   []string{"path"},
+			},
+		},
+		Annotations: &MCPToolAnnotations{
+			Title:           "File Reader",
+			ReadOnlyHint:    &readOnly,
+			DestructiveHint: &destructive,
+			IdempotentHint:  Ptr(true),
+		},
+	}
+
+	output, err := Marshal(tool)
+	require.NoError(t, err)
+
+	s := string(output)
+
+	// Annotations must be absent — json:"-" must suppress the entire field
+	assert.NotContains(t, s, "annotations", "annotations field must not appear in provider payload")
+	assert.NotContains(t, s, "readOnlyHint", "readOnlyHint must not appear in provider payload")
+	assert.NotContains(t, s, "destructiveHint", "destructiveHint must not appear in provider payload")
+	assert.NotContains(t, s, "idempotentHint", "idempotentHint must not appear in provider payload")
+	assert.NotContains(t, s, "File Reader", "annotation title must not appear in provider payload")
+
+	// The function definition itself must still be present
+	assert.Contains(t, s, "read_file", "function name must be in payload")
+	assert.Contains(t, s, "path", "parameter must be in payload")
+}
+
+// TestSonic_ChatTool_DeepCopy_AnnotationsPreserved verifies that DeepCopyChatTool
+// correctly copies Annotations so they survive any clone-based flows.
+func TestSonic_ChatTool_DeepCopy_AnnotationsPreserved(t *testing.T) {
+	readOnly := true
+	idempotent := false
+
+	original := ChatTool{
+		Type: ChatToolTypeFunction,
+		Function: &ChatToolFunction{
+			Name: "query_db",
+		},
+		Annotations: &MCPToolAnnotations{
+			Title:          "DB Query",
+			ReadOnlyHint:   &readOnly,
+			IdempotentHint: &idempotent,
+		},
+	}
+
+	copied := DeepCopyChatTool(original)
+
+	require.NotNil(t, copied.Annotations)
+	assert.Equal(t, "DB Query", copied.Annotations.Title)
+	assert.Equal(t, true, *copied.Annotations.ReadOnlyHint)
+	assert.Equal(t, false, *copied.Annotations.IdempotentHint)
+	assert.Nil(t, copied.Annotations.DestructiveHint)
+	assert.Nil(t, copied.Annotations.OpenWorldHint)
+
+	// Verify it's a true deep copy — mutations don't bleed back
+	*original.Annotations.ReadOnlyHint = false
+	assert.True(t, *copied.Annotations.ReadOnlyHint, "copy must not share pointer with original")
+}
+
+// TestSonic_ChatTool_DeepCopy_NilAnnotationsStaysNil verifies that a tool
+// without annotations deep-copies cleanly with Annotations remaining nil.
+func TestSonic_ChatTool_DeepCopy_NilAnnotationsStaysNil(t *testing.T) {
+	original := ChatTool{
+		Type:     ChatToolTypeFunction,
+		Function: &ChatToolFunction{Name: "plain_tool"},
+	}
+
+	copied := DeepCopyChatTool(original)
+
+	assert.Nil(t, copied.Annotations, "Annotations should stay nil when original has none")
+}
diff --git a/core/schemas/trace.go b/core/schemas/trace.go
index 9a69980d3c..d6862d4d4e 100644
--- a/core/schemas/trace.go
+++ b/core/schemas/trace.go
@@ -8,6 +8,7 @@ import (
 
 // Trace represents a distributed trace that captures the full lifecycle of a request
 type Trace struct {
+	RequestID string // Request ID for the trace
 	TraceID   string // Unique identifier for this trace
 	ParentID  string // Parent trace ID from incoming W3C traceparent header
 	RootSpan  *Span  // The root span of this trace
@@ -15,6 +16,7 @@ type Trace struct {
 	StartTime  time.Time        // When the trace started
 	EndTime    time.Time        // When the trace completed
 	Attributes map[string]any   // Additional attributes for the trace
+	PluginLogs []PluginLogEntry // Plugin log entries accumulated during request processing
 	mu         sync.Mutex       // Mutex for thread-safe span operations
 }
@@ -37,15 +39,49 @@ func (t *Trace) GetSpan(spanID string) *Span {
 	return nil
 }
 
+// GetRequestID retrieves the request ID from the trace
+func (t *Trace) GetRequestID() string {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.RequestID
+}
+
+// SetRequestID sets the request ID for the trace
+func (t *Trace) SetRequestID(requestID string) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.RequestID = requestID
+}
+
 // Reset clears the trace for reuse from pool
 func (t *Trace) Reset() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.RequestID = ""
 	t.TraceID = ""
 	t.ParentID = ""
 	t.RootSpan = nil
+	for i := range t.Spans {
+		t.Spans[i] = nil
+	}
 	t.Spans = t.Spans[:0]
 	t.StartTime = time.Time{}
 	t.EndTime = time.Time{}
 	t.Attributes = nil
+	for i := range t.PluginLogs {
+		t.PluginLogs[i] = PluginLogEntry{}
+	}
+	t.PluginLogs = t.PluginLogs[:0]
+}
+
+// AppendPluginLogs appends plugin log entries to the trace in a thread-safe manner.
+func (t *Trace) AppendPluginLogs(logs []PluginLogEntry) {
+	if len(logs) == 0 {
+		return
+	}
+	t.mu.Lock()
+	t.PluginLogs = append(t.PluginLogs, logs...)
+	t.mu.Unlock()
 }
 
 // Span represents a single operation within a trace
diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go
index 06f6487f8c..23c5d4cc4c 100644
--- a/core/schemas/tracer.go
+++ b/core/schemas/tracer.go
@@ -14,7 +14,8 @@ type SpanHandle interface{}
 // This is the return type for tracer's streaming accumulation methods.
 type StreamAccumulatorResult struct {
 	RequestID string // Request ID
-	Model     string // Model used
+	RequestedModel string // Original model requested by the caller
+	ResolvedModel  string // Actual model used by the provider (equals RequestedModel when no alias mapping exists)
 	Provider  ModelProvider // Provider used
 	Status    string        // Status of the stream
 	Latency   int64         // Latency in milliseconds
@@ -38,7 +39,8 @@ type StreamAccumulatorResult struct {
 type Tracer interface {
 	// CreateTrace creates a new trace with optional parent ID and returns the trace ID.
 	// The parentID can be extracted from W3C traceparent headers for distributed tracing.
-	CreateTrace(parentID string) string
+	// The requestID is optional and can be used to identify the request.
+	CreateTrace(parentID string, requestID ...string) string
 
 	// EndTrace completes a trace and returns the trace data for observation/export.
 	// After this call, the trace is removed from active tracking and returned for cleanup.
@@ -68,7 +70,7 @@ type Tracer interface {
 
 	// PopulateLLMResponseAttributes populates all LLM-specific response attributes on the span.
 	// This includes output messages, tokens, usage stats, and error information if present.
-	PopulateLLMResponseAttributes(handle SpanHandle, resp *BifrostResponse, err *BifrostError)
+	PopulateLLMResponseAttributes(ctx *BifrostContext, handle SpanHandle, resp *BifrostResponse, err *BifrostError)
 
 	// StoreDeferredSpan stores a span handle for later completion (used for streaming requests).
 	// The span handle is stored keyed by trace ID so it can be retrieved when the stream completes.
@@ -111,6 +113,14 @@ type Tracer interface {
 	// The ctx parameter must contain the stream end indicator for proper final chunk detection.
 	ProcessStreamingChunk(traceID string, isFinalChunk bool, result *BifrostResponse, err *BifrostError) *StreamAccumulatorResult
 
+	// AttachPluginLogs appends plugin log entries to the trace identified by traceID.
+	// Thread-safe. Should be called after plugin hooks complete, before trace completion.
+	AttachPluginLogs(traceID string, logs []PluginLogEntry)
+
+	// CompleteAndFlushTrace ends a trace, exports it to observability plugins, and
+	// releases the trace resources. Used by transports that bypass normal HTTP trace completion.
+	CompleteAndFlushTrace(traceID string)
+
 	// Stop releases resources associated with the tracer.
 	// Should be called during shutdown to stop background goroutines.
 	Stop()
@@ -121,7 +131,7 @@ type Tracer interface {
 type NoOpTracer struct{}
 
 // CreateTrace returns an empty string (no trace created).
-func (n *NoOpTracer) CreateTrace(_ string) string { return "" }
+func (n *NoOpTracer) CreateTrace(_ string, _ ...string) string { return "" }
 
 // EndTrace returns nil (no trace to end).
 func (n *NoOpTracer) EndTrace(_ string) *Trace { return nil }
@@ -144,7 +154,7 @@ func (n *NoOpTracer) AddEvent(_ SpanHandle, _ string, _ map[string]any) {}
 func (n *NoOpTracer) PopulateLLMRequestAttributes(_ SpanHandle, _ *BifrostRequest) {}
 
 // PopulateLLMResponseAttributes does nothing.
-func (n *NoOpTracer) PopulateLLMResponseAttributes(_ SpanHandle, _ *BifrostResponse, _ *BifrostError) {
+func (n *NoOpTracer) PopulateLLMResponseAttributes(_ *BifrostContext, _ SpanHandle, _ *BifrostResponse, _ *BifrostError) {
 }
 
 // StoreDeferredSpan does nothing.
@@ -176,6 +186,12 @@ func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse,
 	return nil
 }
 
+// AttachPluginLogs does nothing.
+func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {}
+
+// CompleteAndFlushTrace does nothing.
+func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {}
+
 // Stop does nothing.
 func (n *NoOpTracer) Stop() {}
diff --git a/core/schemas/transcriptions.go b/core/schemas/transcriptions.go
index 7308714ed5..1cd801be98 100644
--- a/core/schemas/transcriptions.go
+++ b/core/schemas/transcriptions.go
@@ -14,15 +14,37 @@ func (r *BifrostTranscriptionRequest) GetRawRequestBody() []byte {
 }
 
 type BifrostTranscriptionResponse struct {
-	Duration    *float64                   `json:"duration,omitempty"` // Duration in seconds
-	Language    *string                    `json:"language,omitempty"` // e.g., "english"
-	LogProbs    []TranscriptionLogProb     `json:"logprobs,omitempty"`
-	Segments    []TranscriptionSegment     `json:"segments,omitempty"`
-	Task        *string                    `json:"task,omitempty"` // e.g., "transcribe"
-	Text        string                     `json:"text"`
-	Usage       *TranscriptionUsage        `json:"usage,omitempty"`
-	Words       []TranscriptionWord        `json:"words,omitempty"`
-	ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
+	Duration       *float64                   `json:"duration,omitempty"` // Duration in seconds
+	Language       *string                    `json:"language,omitempty"` // e.g., "english"
+	LogProbs       []TranscriptionLogProb     `json:"logprobs,omitempty"`
+	Segments       []TranscriptionSegment     `json:"segments,omitempty"`
+	Task           *string                    `json:"task,omitempty"` // e.g., "transcribe"
+	Text           string                     `json:"text"`
+	Usage          *TranscriptionUsage        `json:"usage,omitempty"`
+	Words          []TranscriptionWord        `json:"words,omitempty"`
+	ResponseFormat *string                    `json:"-"` // Set by provider for non-JSON formats (text, srt, vtt); used by integration response converters
+	ExtraFields    BifrostResponseExtraFields `json:"extra_fields"`
+}
+
+func (r *BifrostTranscriptionResponse) BackfillParams(req *BifrostTranscriptionRequest) {
+	if r == nil || req == nil || req.Params == nil || req.Params.ResponseFormat == nil {
+		return
+	}
+	r.ResponseFormat = req.Params.ResponseFormat
+}
+
+// IsPlainTextTranscriptionFormat returns true if the given response format
+// produces a plain-text response body (not JSON).
+func IsPlainTextTranscriptionFormat(format *string) bool {
+	if format == nil {
+		return false
+	}
+	switch *format {
+	case "text", "srt", "vtt":
+		return true
+	default:
+		return false
+	}
 }
 
 type TranscriptionInput struct {
@@ -31,17 +53,17 @@ type TranscriptionInput struct {
 }
 
 type TranscriptionParameters struct {
-	Language               *string  `json:"language,omitempty"`
-	Prompt                 *string  `json:"prompt,omitempty"`
-	ResponseFormat         *string  `json:"response_format,omitempty"` // Default is "json"
-	Temperature            *float64 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0)
-	TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json
-	Include                []string `json:"include,omitempty"` // Additional response info (e.g., logprobs)
-	Format                 *string  `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini
-	MaxLength              *int     `json:"max_length,omitempty"` // Maximum length of the transcription used by HuggingFace
-	MinLength              *int     `json:"min_length,omitempty"` // Minimum length of the transcription used by HuggingFace
-	MaxNewTokens           *int     `json:"max_new_tokens,omitempty"` // Maximum new tokens to generate used by HuggingFace
-	MinNewTokens           *int     `json:"min_new_tokens,omitempty"` // Minimum new tokens to generate used by HuggingFace
+	Language               *string  `json:"language,omitempty"`
+	Prompt                 *string  `json:"prompt,omitempty"`
+	ResponseFormat         *string  `json:"response_format,omitempty"`         // Default is "json"
+	Temperature            *float64 `json:"temperature,omitempty"`             // Sampling temperature (0.0-1.0)
+	TimestampGranularities []string `json:"timestamp_granularities,omitempty"` // "word" and/or "segment"; requires response_format=verbose_json
+	Include                []string `json:"include,omitempty"`                 // Additional response info (e.g., logprobs)
+	Format                 *string  `json:"file_format,omitempty"`             // Type of file, not required in openai, but required in gemini
+	MaxLength              *int     `json:"max_length,omitempty"`              // Maximum length of the transcription used by HuggingFace
+	MinLength              *int     `json:"min_length,omitempty"`              // Minimum length of the transcription used by HuggingFace
+	MaxNewTokens           *int     `json:"max_new_tokens,omitempty"`          // Maximum new tokens to generate used by HuggingFace
+	MinNewTokens           *int     `json:"min_new_tokens,omitempty"`          // Minimum new tokens to generate used by HuggingFace
 
 	// Elevenlabs-specific fields
 	AdditionalFormats []TranscriptionAdditionalFormat `json:"additional_formats,omitempty"`
@@ -132,4 +154,3 @@ type BifrostTranscriptionStreamResponse struct {
 	Usage       *TranscriptionUsage        `json:"usage,omitempty"`
 	ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
 }
-
diff --git a/core/schemas/utils.go b/core/schemas/utils.go
index 7c59020f40..5e61c84ade 100644
--- a/core/schemas/utils.go
+++ b/core/schemas/utils.go
@@ -879,6 +879,30 @@ func DeepCopyChatTool(original ChatTool) ChatTool {
 		}
 	}
 
+	// Deep copy Annotations if present
+	if original.Annotations != nil {
+		copyAnnotations := &MCPToolAnnotations{
+			Title: original.Annotations.Title,
+		}
+		if original.Annotations.ReadOnlyHint != nil {
+			v := *original.Annotations.ReadOnlyHint
+			copyAnnotations.ReadOnlyHint = &v
+		}
+		if original.Annotations.DestructiveHint != nil {
+			v := *original.Annotations.DestructiveHint
+			copyAnnotations.DestructiveHint = &v
+		}
+		if original.Annotations.IdempotentHint != nil {
+			v := *original.Annotations.IdempotentHint
+			copyAnnotations.IdempotentHint = &v
+		}
+		if original.Annotations.OpenWorldHint != nil {
+			v := *original.Annotations.OpenWorldHint
+			copyAnnotations.OpenWorldHint = &v
+		}
+		copyTool.Annotations = copyAnnotations
+	}
+
 	// Deep copy Custom if present
 	if original.Custom != nil {
 		copyTool.Custom = &ChatToolCustom{}
diff --git a/core/schemas/videos.go b/core/schemas/videos.go
index 9e133c7d52..b1d134889c 100644
--- a/core/schemas/videos.go
+++ b/core/schemas/videos.go
@@ -156,6 +156,9 @@ func (r *BifrostVideoGenerationResponse) BackfillParams(req *BifrostRequest) {
 	if seconds != nil {
 		r.Seconds = seconds
 	}
+	if r.Model == "" && req.VideoGenerationRequest != nil {
+		r.Model = req.VideoGenerationRequest.Model
+	}
 }
 
 // --- Video Remix ---
diff --git a/core/utils.go b/core/utils.go
index ed8f40ebf4..0e1b45fefd 100644
--- a/core/utils.go
+++ b/core/utils.go
@@ -11,6 +11,7 @@ import (
 	"math/rand"
 	"net"
 	"net/url"
+	"slices"
 	"strings"
 	"time"
 
@@ -86,19 +87,19 @@ func Ptr[T any](v T) *T {
 }
 
 // providerRequiresKey returns true if the given provider requires an API key for authentication.
-// Some providers like Ollama, SGL, and vLLM are keyless and don't require API keys.
-func providerRequiresKey(providerKey schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) bool {
+func providerRequiresKey(customConfig *schemas.CustomProviderConfig) bool {
 	// Keyless custom providers are not allowed for Bedrock.
 	if customConfig != nil && customConfig.IsKeyLess && customConfig.BaseProviderType != schemas.Bedrock {
 		return false
 	}
-	return !IsKeylessProvider(providerKey)
+	return true
 }
 
-// canProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty.
-// Some providers like Vertex and Bedrock have their credentials in additional key configs..
+// CanProviderKeyValueBeEmpty returns true if the given provider allows the API key to be empty.
+// Some providers like Vertex and Bedrock have their credentials in additional key configs.
+// Ollama and SGL are keyless (API Key is optional) but use per-key server URLs. func CanProviderKeyValueBeEmpty(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure + return providerKey == schemas.Vertex || providerKey == schemas.Bedrock || providerKey == schemas.VLLM || providerKey == schemas.Azure || providerKey == schemas.Ollama || providerKey == schemas.SGL } func isKeySkippingAllowed(providerKey schemas.ModelProvider) bool { @@ -131,6 +132,56 @@ func validateRequest(req *schemas.BifrostRequest) *schemas.BifrostError { return nil } +// validateKey validates the given key. +func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error { + // Validate the key for the provider + switch providerKey { + case schemas.Azure: + if key.AzureKeyConfig == nil { + return fmt.Errorf("azure_key_config is required") + } + if key.AzureKeyConfig.Endpoint.GetValue() == "" { + return fmt.Errorf("azure_key_config.endpoint is required") + } + case schemas.Bedrock: + // Key is valid if either: + // 1. BedrockKeyConfig is provided + // 2. 
Value is provided and is not empty + if key.BedrockKeyConfig == nil { + if key.Value.GetValue() == "" { + return fmt.Errorf("either value in key or bedrock_key_config is required") + } + key.BedrockKeyConfig = &schemas.BedrockKeyConfig{} + } + case schemas.Vertex: + if key.VertexKeyConfig == nil { + return fmt.Errorf("vertex_key_config is required") + } + case schemas.VLLM: + if key.VLLMKeyConfig == nil { + return fmt.Errorf("vllm_key_config is required") + } + if key.VLLMKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("vllm_key_config.url is required") + } + case schemas.Ollama: + if key.OllamaKeyConfig == nil { + return fmt.Errorf("ollama_key_config is required") + } + if key.OllamaKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("ollama_key_config.url is required") + } + case schemas.SGL: + if key.SGLKeyConfig == nil { + return fmt.Errorf("sgl_key_config is required") + } + if key.SGLKeyConfig.URL.GetValue() == "" { + return fmt.Errorf("sgl_key_config.url is required") + } + } + return nil +} + // IsRateLimitErrorMessage checks if an error message indicates a rate limit issue func IsRateLimitErrorMessage(errorMessage string) bool { if errorMessage == "" { @@ -175,7 +226,7 @@ func newBifrostErrorFromMsg(message string) *schemas.BifrostError { // newBifrostCtxDoneError creates a BifrostError from a cancelled/expired context. // It distinguishes DeadlineExceeded (504 RequestTimedOut) from Canceled (499 RequestCancelled). 
-func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelProvider, model string, requestType schemas.RequestType, stage string) *schemas.BifrostError { +func newBifrostCtxDoneError(ctx *schemas.BifrostContext, stage string) *schemas.BifrostError { var statusCode int var errorType string var message string @@ -199,11 +250,6 @@ func newBifrostCtxDoneError(ctx *schemas.BifrostContext, provider schemas.ModelP Message: message, Error: ctx.Err(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: requestType, - Provider: provider, - ModelRequested: model, - }, } } @@ -230,6 +276,9 @@ func newBifrostMessageChan(message *schemas.BifrostResponse) chan *schemas.Bifro func clearCtxForFallback(ctx *schemas.BifrostContext) { ctx.ClearValue(schemas.BifrostContextKeyAPIKeyID) ctx.ClearValue(schemas.BifrostContextKeyAPIKeyName) + ctx.ClearValue(schemas.BifrostContextKeyGovernanceIncludeOnlyKeys) + ctx.ClearValue(schemas.BifrostContextKeyChangeRequestType) + ctx.ClearValue(schemas.BifrostContextKeyAttemptTrail) } var supportedBaseProvidersSet = func() map[schemas.ModelProvider]struct{} { @@ -261,11 +310,6 @@ func IsStandardProvider(providerKey schemas.ModelProvider) bool { return ok } -// IsKeylessProvider reports whether providerKey is a keyless provider. -func IsKeylessProvider(providerKey schemas.ModelProvider) bool { - return providerKey == schemas.Ollama || providerKey == schemas.SGL -} - // IsStreamRequestType returns true if the given request type is a stream request. 
func IsStreamRequestType(reqType schemas.RequestType) bool { return reqType == schemas.TextCompletionStreamRequest || reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest || reqType == schemas.ImageGenerationStreamRequest || reqType == schemas.ImageEditStreamRequest || reqType == schemas.PassthroughStreamRequest || reqType == schemas.WebSocketResponsesRequest || reqType == schemas.RealtimeRequest @@ -336,14 +380,14 @@ func IsFinalChunk(ctx *schemas.BifrostContext) bool { return false } -// GetResponseFields extracts the request type, provider, and model from the result or error -func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, model string) { +// GetResponseFields extracts the request type, provider, original model, and resolved model from the result or error. +func GetResponseFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (requestType schemas.RequestType, provider schemas.ModelProvider, originalModel string, resolvedModel string) { if result != nil { extraFields := result.GetExtraFields() - return extraFields.RequestType, extraFields.Provider, extraFields.ModelRequested + return extraFields.RequestType, extraFields.Provider, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed } if err != nil { - return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.ModelRequested + return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.OriginalModelRequested, err.ExtraFields.ResolvedModelUsed } return } @@ -366,7 +410,9 @@ func GetErrorMessage(err *schemas.BifrostError) string { if err == nil { return "" } - if err.StatusCode != nil { + if err.Error != nil && err.Error.Message != "" { + return err.Error.Message + } else if err.StatusCode != nil { switch *err.StatusCode { case 
401: return "unauthorized" @@ -392,8 +438,6 @@ func GetErrorMessage(err *schemas.BifrostError) string { } return fmt.Sprintf("HTTP %d error", *err.StatusCode) } - } else if err.Error != nil && err.Error.Message != "" { - return err.Error.Message } else if err.Type != nil { return *err.Type } else { @@ -544,3 +588,44 @@ func buildSessionKey(providerKey schemas.ModelProvider, sessionID string, model } return "session:" + string(providerKey) + ":" + hashedSessionID + ":" + hashSHA256(discriminator) } + +// isPromptOptionalImageEditType returns true for edit task types that do not require a text prompt. +// It normalises hyphenated variants (e.g. "erase-object") to underscore form before matching. +func isPromptOptionalImageEditType(t *string) bool { + if t == nil { + return false + } + normalized := strings.ToLower(strings.TrimSpace(*t)) + normalized = strings.ReplaceAll(normalized, "-", "_") + return slices.Contains( + []string{"background_removal", "remove_background", "remove_bg", "erase_object", "upscale_fast"}, + normalized, + ) +} + +// wrapConvertedStreamPostHookRunner wraps a PostHookRunner so that streaming +// responses produced by a type-converted request are converted back to the +// caller's original type before the post-hook runs. 
+func wrapConvertedStreamPostHookRunner(postHookRunner schemas.PostHookRunner, targetType schemas.RequestType) schemas.PostHookRunner { + return func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil { + switch targetType { + case schemas.ChatCompletionRequest: + // text→chat: convert chat stream chunk back to text completion + if result.ChatResponse != nil { + if converted := result.ChatResponse.ToBifrostTextCompletionResponse(); converted != nil { + result = &schemas.BifrostResponse{TextCompletionResponse: converted} + } + } + case schemas.ResponsesRequest: + // chat→responses: convert responses stream chunk back to chat + if result.ResponsesStreamResponse != nil { + if converted := result.ResponsesStreamResponse.ToBifrostChatResponse(); converted != nil { + result = &schemas.BifrostResponse{ChatResponse: converted} + } + } + } + } + return postHookRunner(ctx, result, bifrostErr) + } +} diff --git a/core/version b/core/version index fd4ca57b8d..4cda8f19ed 100644 --- a/core/version +++ b/core/version @@ -1 +1 @@ -1.4.19 +1.5.2 diff --git a/docs/changelogs/v1.4.23.mdx b/docs/changelogs/v1.4.23.mdx new file mode 100644 index 0000000000..2a5f8d2049 --- /dev/null +++ b/docs/changelogs/v1.4.23.mdx @@ -0,0 +1,110 @@ +--- +title: "v1.4.23" +description: "v1.4.23 changelog - 2026-04-18" +--- + + + ```bash + npx -y @maximhq/bifrost --transport-version v1.4.23 + ``` + + + ```bash + docker pull maximhq/bifrost:v1.4.23 + docker run -p 8080:8080 maximhq/bifrost:v1.4.23 + ``` + + + + +## ✨ Features + +- **Claude Opus 4.7 Support** — Added compatibility for Anthropic's Claude Opus 4.7 model, including adaptive thinking, task-budgets beta header, `display` parameter handling, and "xhigh" effort mapping +- **Anthropic Structured Outputs** — Added `response_format` and structured output support for Anthropic models across chat completions and Responses API, including 
JSON-schema and JSON-object formats with order-preserving merge of additional model request fields (thanks [@emirhanmutlu-natuvion](https://github.com/emirhanmutlu-natuvion)!)
+- **MCP Tool Annotations** — Preserve MCP tool annotations (`title`, `readOnlyHint`, `destructiveHint`, `idempotentHint`, `openWorldHint`) in bidirectional conversion between MCP tools and Bifrost chat tools so agents can reason about tool behavior
+- **Anthropic Server Tools** — Expanded the Anthropic chat schema and Responses converters to surface server-side tools (web search, code execution, computer use containers) end-to-end
+
+## 🐞 Fixed
+
+- **Provider Queue Shutdown Panic** — Eliminated `send on closed channel` panics during provider queue shutdown by leaving queue channels open and exiting workers via the `done` signal; stale producers transparently re-route to new queues during `UpdateProvider`, with rollback on failed updates
+- **OpenAI Tool Result Output** — Flatten array-form `tool_result` output into a newline-joined string before marshaling for the Responses API so strict upstreams (Ollama Cloud, openai-go typed models) no longer reject it with HTTP 400; non-text blocks (images, files) are preserved (thanks [@martingiguere](https://github.com/martingiguere)!)
+- **vLLM Token Usage** — Treat `delta.content=""` the same as `nil` in streaming so the synthesis chunk retains its `finish_reason`, restoring token usage attribution in logs and the UI
+- **Config Schema Validator** — Corrected JSON-path lookups for the concurrency and SCIM blocks in the schema validation script, and reformatted `transports/config.schema.json` for readability
+- **CI Egress Hardening** — Switched `step-security/harden-runner` from `audit` to `block` across all GitHub Actions workflows, with explicit `allowed-endpoints` per job
+- **Gemini Tool Outputs** — Handle content-block tool outputs in the Responses API path for `function_call_output` messages (thanks [@tom-diacono](https://github.com/tom-diacono)!)
+- **Bedrock Streaming** — Emit a `message_stop` event for the Anthropic invoke stream and merge `anthropic-beta` headers case-insensitively (thanks [@tefimov](https://github.com/tefimov)!)
+- **Bedrock Tool Images** — Preserve image content blocks in tool results when converting Anthropic Messages to the Bedrock Converse API (thanks [@Edward-Upton](https://github.com/Edward-Upton)!)
+- **Gemini Thinking Level** — Preserved `thinkingLevel` parameters across round-trip conversions and corrected finish reason mapping
+- **Anthropic WebSearch** — Removed the Claude Code user agent restriction so WebSearch tool arguments flow for all clients
+- **Responses Streaming Errors** — Capture errors mid-stream in the Responses API so transport clients see failures instead of silent termination
+- **Anthropic Request Fallbacks** — Dropped fallback fields from outgoing Anthropic requests to avoid schema validation errors
+- **Async Context Propagation** — Preserve context values in async requests so downstream handlers retain request-scoped data
+- **Custom Providers** — Allow custom providers without a list-models endpoint to accept any model, rather than restricting models at virtual key registration
+- **OTEL Plugin** — Default `insecure` to `true` in config.json and include fallbacks in emitted OTEL metrics
+- **Payload Marshalling** — Removed an unnecessary payload marshalling step in the transport path
+- **Helm mcpClientConfig** — Fixed templating for `mcpClientConfig` (thanks [@crust3780](https://github.com/crust3780)!)
+- **Helm Chart** — Refreshed the Helm chart with validation fixes and removed the prerelease tag
+
+
+
+- fix: OpenAI provider - flatten array-form tool_result output for Responses API (thanks [@martingiguere](https://github.com/martingiguere)!)
+- fix: Gemini provider - handle content block tool outputs in Responses API path (thanks [@tom-diacono](https://github.com/tom-diacono)!)
+- fix: case-insensitive `anthropic-beta` merge in `MergeBetaHeaders` +- fix: Bedrock provider - emit message_stop event for Anthropic invoke stream (thanks [@tefimov](https://github.com/tefimov)!) +- fix: Bedrock provider - preserve image content in tool results for Converse API (thanks [@Edward-Upton](https://github.com/Edward-Upton)!) +- fix: gemini preserves thinkingLevel parameters during round-trip and finish reason mapping +- fix: WebSearch tool argument handling for all clients by removing the Claude Code user agent restriction +- fix: capture responses streaming API errors +- fix: delete fallbacks from outgoing Anthropic requests +- feat: claude-opus-4-7 compatibility +- fix: token usage for vllm + + + +- chore: upgraded core to v1.4.20 +- fix: preserve context values in async requests +- fix: capture responses streaming API errors +- fix: otel plugin fixes +- fix: allow custom providers without a list models endpoint to register any model + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 +- fix: allow custom providers without a list models endpoint to pass in any model rather than restrict it on vk + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 +- fix: capture responses streaming API errors + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 +- fix: sets default for `insecure` to `true` for config.json +- fix: includes fallbacks in otel metrics + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + + +- chore: upgraded core to v1.4.20 and framework to v1.2.39 + + diff --git a/docs/contributing/setting-up-repo.mdx b/docs/contributing/setting-up-repo.mdx index d4ea991fb0..a787c6ba19 100644 --- a/docs/contributing/setting-up-repo.mdx 
+++ b/docs/contributing/setting-up-repo.mdx @@ -91,6 +91,16 @@ This command will: The `make dev` command handles all setup automatically. You can skip the manual setup steps below if this works for you. +#### Alternative: Using Pulse + +If you prefer [Pulse](https://github.com/Pratham-Mishra04/pulse) over Air for hot reloading, use: + +```bash +make dev-pulse +``` + +This runs the same development environment but uses `pulse.yaml` for hot reloading instead of `.air.toml`. + ### Manual Setup (Alternative) If you prefer to set up components manually: @@ -154,7 +164,8 @@ The Makefile provides numerous commands for development: ### Development Commands ```bash -make dev # Start complete development environment (recommended) +make dev # Start complete development environment using Air for hot reloading +make dev-pulse # Start complete development environment using Pulse for hot reloading make build # Build UI and bifrost-http binary make run # Build and run (no hot reload) make clean # Clean build artifacts @@ -373,6 +384,9 @@ which air || go install github.com/air-verse/air@latest # Check if .air.toml exists in transports/bifrost-http/ ls transports/bifrost-http/.air.toml + +# Alternatively, use Pulse instead of Air +make dev-pulse ``` ### Getting Help diff --git a/docs/deployment-guides/config-json.mdx b/docs/deployment-guides/config-json.mdx new file mode 100644 index 0000000000..413c0e5a51 --- /dev/null +++ b/docs/deployment-guides/config-json.mdx @@ -0,0 +1,313 @@ +--- +title: "Quick Start" +description: "Configure Bifrost using a config.json file — GitOps-friendly, no-UI deployments, and multinode OSS setups" +icon: "file-code" +--- + + +**Full schema reference:** [`https://www.getbifrost.ai/schema`](https://www.getbifrost.ai/schema) + + +`config.json` lets you configure every aspect of Bifrost through a single declarative file. 
It is the right choice for GitOps workflows, CI/CD pipelines, headless deployments, and multinode OSS setups where a central configuration file is shared across all replicas. + +--- + +## Two Configuration Modes + +Bifrost supports **two mutually exclusive modes**. You cannot run both at the same time. + +| Mode | When | Behaviour | +|------|------|-----------| +| **Web UI / database** | No `config.json`, or `config.json` with `config_store` enabled | Full UI available, configuration stored in SQLite or PostgreSQL | +| **File-based (`config.json`)** | `config.json` present, `config_store` disabled | UI disabled, all config loaded from file at startup, restart required for changes | + + +See [Setting Up](/quickstart/gateway/setting-up#two-configuration-modes) for a full explanation of both modes and how `config_store` bootstrapping works. + + +--- + +## Minimal Working Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "client": { + "drop_excess_requests": false, + "enable_logging": true + }, + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + }, + "config_store": { + "enabled": false + } +} +``` + +Save this as `config.json` in your app directory and start Bifrost: + +```bash +# NPX +npx -y @maximhq/bifrost -app-dir ./data + +# Docker +docker run -p 8080:8080 \ + -v $(pwd)/data:/app/data \ + -e OPENAI_API_KEY=sk-... \ + -e BIFROST_ENCRYPTION_KEY=your-32-byte-key \ + maximhq/bifrost +``` + +Make your first call: + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +--- + +## Environment Variable References + +Never put secrets directly in `config.json`. 
Use the `env.` prefix to reference any environment variable: + +```json +{ + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "providers": { + "openai": { + "keys": [ + { + "name": "primary", + "value": "env.OPENAI_API_KEY", + "weight": 1.0 + } + ] + } + } +} +``` + +Set the actual values through your deployment platform — shell environment, Docker `-e`, Kubernetes Secrets mounted as env vars, or a `.env` file. + +--- + +## Schema Validation + +Add `$schema` to every `config.json` for IDE autocomplete and inline validation: + +```json +{ + "$schema": "https://www.getbifrost.ai/schema" +} +``` + +Editors (VS Code, JetBrains, Neovim with LSP) will show completions and flag invalid fields as you type. + +--- + +## Production Example + +A production-ready file with PostgreSQL storage, multi-provider setup, governance, and common plugins: + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "client": { + "initial_pool_size": 500, + "drop_excess_requests": true, + "enable_logging": true, + "log_retention_days": 90, + "enforce_auth_on_inference": true, + "allow_direct_keys": false, + "allowed_origins": ["https://app.yourcompany.com"] + }, + + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 120, + "max_retries": 3 + } + }, + "anthropic": { + "keys": [ + { + "name": "anthropic-primary", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + }, + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + }, + + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + 
"password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + } +} +``` + +--- + +## Example Configs + +Ready-to-use reference configurations from the [examples/configs](https://github.com/maximhq/bifrost/tree/main/examples/configs) directory on GitHub: + + + + + +| Example | Description | +|---------|-------------| +| [noconfigstorenologstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/noconfigstorenologstore/config.json) | Bare-minimum file-only mode — no database, no UI, providers loaded from file | +| [partial](https://github.com/maximhq/bifrost/blob/main/examples/configs/partial/config.json) | SQLite config store with a minimal provider setup | +| [v1compat](https://github.com/maximhq/bifrost/blob/main/examples/configs/v1compat/config.json) | `"version": 1` for v1.4.x array semantics (empty = allow all) | + + + + + +| Example | Description | +|---------|-------------| +| [withconfigstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withconfigstore/config.json) | SQLite config store (Web UI enabled) | +| [withconfigstorelogsstorepostgres](https://github.com/maximhq/bifrost/blob/main/examples/configs/withconfigstorelogsstorepostgres/config.json) | PostgreSQL for both config store and logs store | +| [withlogstore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withlogstore/config.json) | SQLite logs store | +| [withobjectstorages3](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobjectstorages3/config.json) | S3 object storage offload for logs | +| [withobjectstoragegcs](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobjectstoragegcs/config.json) | GCS object storage offload for logs | +| [withvectorstoreweaviate](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvectorstoreweaviate/config.json) | Weaviate vector store (with 
[docker-compose](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvectorstoreweaviate/docker-compose.yml)) | + + + + + +| Example | Description | +|---------|-------------| +| [withsemanticcache](https://github.com/maximhq/bifrost/blob/main/examples/configs/withsemanticcache/config.json) | Semantic cache backed by Weaviate | +| [withsemanticcachevalkey](https://github.com/maximhq/bifrost/blob/main/examples/configs/withsemanticcachevalkey/config.json) | Semantic cache backed by Valkey / Redis | + + + + + +| Example | Description | +|---------|-------------| +| [withauth](https://github.com/maximhq/bifrost/blob/main/examples/configs/withauth/config.json) | Admin username/password auth (`governance.auth_config`) | +| [withvirtualkeys](https://github.com/maximhq/bifrost/blob/main/examples/configs/withvirtualkeys/config.json) | Virtual keys with provider/model allowlists | +| [withteamscustomers](https://github.com/maximhq/bifrost/blob/main/examples/configs/withteamscustomers/config.json) | Teams and customers with budgets and rate limits | +| [withroutingrules](https://github.com/maximhq/bifrost/blob/main/examples/configs/withroutingrules/config.json) | CEL-based routing rules for dynamic provider/model selection | +| [withpricingoverridesnostore](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpricingoverridesnostore/config.json) | Pricing overrides in file-only mode | +| [withpricingoverridessqlite](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpricingoverridessqlite/config.json) | Pricing overrides with SQLite config store | + + + + + +| Example | Description | +|---------|-------------| +| [withobservability](https://github.com/maximhq/bifrost/blob/main/examples/configs/withobservability/config.json) | Prometheus metrics (telemetry always active, custom labels via `client.prometheus_labels`) | +| [withprompushgateway](https://github.com/maximhq/bifrost/blob/main/examples/configs/withprompushgateway/config.json) 
| Prometheus Push Gateway for multi-instance deployments | +| [withotel](https://github.com/maximhq/bifrost/blob/main/examples/configs/withotel/config.json) | OpenTelemetry traces and metrics | + + + + + +| Example | Description | +|---------|-------------| +| [withdynamicplugin](https://github.com/maximhq/bifrost/blob/main/examples/configs/withdynamicplugin/config.json) | Loading a custom `.so` plugin at startup | +| [withcompat](https://github.com/maximhq/bifrost/blob/main/examples/configs/withcompat/config.json) | SDK compatibility shims (`should_drop_params`, `convert_text_to_chat`) | +| [withframework](https://github.com/maximhq/bifrost/blob/main/examples/configs/withframework/config.json) | Custom model pricing catalog URL and sync interval | +| [withlargepayload](https://github.com/maximhq/bifrost/blob/main/examples/configs/withlargepayload/config.json) | Large payload optimization (streaming without full materialisation) | +| [withwebsocket](https://github.com/maximhq/bifrost/blob/main/examples/configs/withwebsocket/config.json) | WebSocket / Realtime API connection pool tuning | +| [withpostgresmcpclientsinconfig](https://github.com/maximhq/bifrost/blob/main/examples/configs/withpostgresmcpclientsinconfig/config.json) | MCP client definitions seeded from config.json with PostgreSQL store | +| [encryptionmigration](https://github.com/maximhq/bifrost/blob/main/examples/configs/encryptionmigration/config.json) | Migrating to a new encryption key | + + + + + +--- + +## Configuration Guides + + + + Every top-level key, its type, default, and where it is documented + + + Pool size, logging, CORS, header filtering, compat shims, MCP settings + + + OpenAI, Anthropic, Azure, Bedrock, Vertex, Groq, self-hosted + + + config_store, logs_store, vector_store — SQLite, PostgreSQL, object storage + + + Semantic cache, OTel, Maxim, Datadog, custom plugins + + + Virtual keys, budgets, rate limits, routing rules, admin auth + + + Content moderation providers and CEL-based 
rules (enterprise) + + + +--- + +## Next Steps + +1. Configure [provider keys](/providers/supported-providers/overview) +2. Enable [plugins](/plugins/getting-started) +3. Set up [observability](/features/observability/default) +4. Configure [governance](/features/governance/virtual-keys) +5. Deploy [multiple nodes](/deployment-guides/how-to/multinode) with a shared `config.json` diff --git a/docs/deployment-guides/config-json/client.mdx b/docs/deployment-guides/config-json/client.mdx new file mode 100644 index 0000000000..1a974df77b --- /dev/null +++ b/docs/deployment-guides/config-json/client.mdx @@ -0,0 +1,276 @@ +--- +title: "Client Configuration" +description: "Configure the Bifrost client in config.json — connection pool, logging, CORS, header filtering, compat shims, and MCP settings" +icon: "gear" +--- + +The `client` block controls how Bifrost manages its internal worker pool, request logging, authentication enforcement, header policies, SDK compatibility shims, and MCP agent behaviour. + +--- + +## Connection Pool + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `initial_pool_size` | integer | `300` | Pre-allocated worker goroutines per provider queue | +| `drop_excess_requests` | boolean | `false` | Drop requests when queue is full instead of waiting (returns HTTP 429) | + +A larger pool reduces latency spikes under burst load at the cost of higher baseline memory. `500–1000` is a common starting point for production workloads with multiple providers. 
+ +```json +{ + "client": { + "initial_pool_size": 1000, + "drop_excess_requests": true + } +} +``` + +--- + +## Request & Response Logging + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `enable_logging` | boolean | — | Log all LLM requests and responses | +| `disable_content_logging` | boolean | `false` | Strip message content from logs (keeps metadata only) | +| `log_retention_days` | integer | `365` | Days to retain log entries in the store | +| `logging_headers` | array of strings | `[]` | HTTP request headers to capture in log metadata | + +Set `disable_content_logging: true` for HIPAA / PCI compliance workloads where message content must not be persisted. + +```json +{ + "client": { + "enable_logging": true, + "disable_content_logging": true, + "log_retention_days": 90, + "logging_headers": ["x-request-id", "x-user-id"] + } +} +``` + +--- + +## Security & CORS + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `allowed_origins` | array | `["*"]` | CORS allowed origins (use URIs or `"*"`) | +| `allow_direct_keys` | boolean | `false` | Allow callers to pass provider keys directly in requests | +| `enforce_auth_on_inference` | boolean | `false` | Require auth (virtual key, API key, or user token) on `/v1/*` inference routes | +| `max_request_body_size_mb` | integer | `100` | Maximum allowed request body size in MB | +| `whitelisted_routes` | array of strings | `[]` | Routes that bypass auth middleware | +| `allowed_headers` | array of strings | `[]` | Additional headers permitted for CORS and WebSocket | + +```json +{ + "client": { + "allowed_origins": [ + "https://app.yourcompany.com", + "https://admin.yourcompany.com" + ], + "allow_direct_keys": false, + "enforce_auth_on_inference": true, + "max_request_body_size_mb": 50, + "whitelisted_routes": ["/health", "/metrics"] + } +} +``` + +--- + +## Header Filtering + +Controls which `x-bf-eh-*` extra headers are forwarded to upstream 
LLM providers. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `header_filter_config.allowlist` | array of strings | `[]` | Only these headers are forwarded (whitelist mode) | +| `header_filter_config.denylist` | array of strings | `[]` | These headers are always blocked | +| `required_headers` | array of strings | `[]` | Headers that must be present on every request (rejected with 400 if missing) | + +When both `allowlist` and `denylist` are empty, all `x-bf-eh-*` headers pass through. Specifying an `allowlist` enables strict whitelist mode — only listed headers are forwarded. + +```json +{ + "client": { + "header_filter_config": { + "allowlist": [ + "x-bf-eh-anthropic-version", + "x-bf-eh-openai-beta" + ], + "denylist": [] + }, + "required_headers": ["x-request-id"] + } +} +``` + +--- + +## Compat Shims + +Compatibility flags that let Bifrost silently adapt request/response shapes for SDK integrations. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `compat.convert_text_to_chat` | boolean | `false` | Wrap legacy `/v1/completions` text requests as chat messages | +| `compat.convert_chat_to_responses` | boolean | `false` | Translate chat completions to Responses API format | +| `compat.should_drop_params` | boolean | `false` | Silently drop unsupported parameters instead of erroring | +| `compat.should_convert_params` | boolean | `false` | Auto-convert parameter values across provider schemas | + +```json +{ + "client": { + "compat": { + "should_drop_params": true, + "convert_text_to_chat": true + } + } +} +``` + +--- + +## MCP Agent Settings + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `mcp_agent_depth` | integer | `10` | Maximum tool-call recursion depth for MCP agent mode | +| `mcp_tool_execution_timeout` | integer | `30` | Timeout per MCP tool execution in seconds | +| `mcp_code_mode_binding_level` | string | — | Code mode binding level: 
`"server"` or `"tool"` | +| `mcp_tool_sync_interval` | integer | `10` | Global tool sync interval in minutes (`0` = disabled) | +| `mcp_disable_auto_tool_inject` | boolean | `false` | When `true`, MCP tools are not automatically injected into requests | + +```json +{ + "client": { + "mcp_agent_depth": 15, + "mcp_tool_execution_timeout": 60, + "mcp_tool_sync_interval": 10 + } +} +``` + +--- + +## Async Jobs + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `async_job_result_ttl` | integer | `3600` | TTL (seconds) for async job results | +| `disable_db_pings_in_health` | boolean | `false` | Exclude database connectivity from `/health` endpoint checks | + +--- + +## Prometheus Labels + +Add custom labels to every Prometheus metric emitted by Bifrost: + +```json +{ + "client": { + "prometheus_labels": ["environment=production", "region=us-east-1"] + } +} +``` + +--- + +## Authentication + +`governance.auth_config` protects the Bifrost dashboard and management API with username/password auth. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `is_enabled` | boolean | `false` | Enable username/password auth | +| `admin_username` | string | — | Admin username | +| `admin_password` | string | — | Admin password (use `env.` reference) | +| `disable_auth_on_inference` | boolean | `false` | Skip auth check on `/v1/*` inference routes | + +```json +{ + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + } +} +``` + + +A top-level `auth_config` is also accepted for backwards compatibility, but `governance.auth_config` is the preferred location. 
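Throughout these examples, `env.`-prefixed values are resolved from environment variables when the config is loaded. A minimal sketch of that resolution convention (an illustration, not Bifrost's actual implementation):

```python
import os

def resolve(value: str) -> str:
    """Resolve an "env.VAR_NAME" reference against the environment.

    Plain strings pass through unchanged; "env."-prefixed strings
    are looked up in the process environment.
    """
    if value.startswith("env."):
        var = value[len("env."):]
        resolved = os.environ.get(var)
        if resolved is None:
            raise KeyError(f"environment variable {var} is not set")
        return resolved
    return value

os.environ["BIFROST_ADMIN_USERNAME"] = "admin"
resolve("env.BIFROST_ADMIN_USERNAME")  # looked up from the environment
resolve("plain-string")                # passed through unchanged
```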
+ + +--- + +## Encryption Key + +```json +{ + "encryption_key": "env.BIFROST_ENCRYPTION_KEY" +} +``` + +| Notes | +|-------| +| Accepts any string; Bifrost derives a 32-byte AES-256 key using Argon2id | +| Can also be set via the `BIFROST_ENCRYPTION_KEY` environment variable | +| Once set and the database is populated, the key cannot be changed without clearing the database | +| Omitting the key stores data in plain text — not recommended for production | + +--- + +## Full Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + }, + + "client": { + "initial_pool_size": 1000, + "drop_excess_requests": true, + + "enable_logging": true, + "disable_content_logging": false, + "log_retention_days": 90, + "logging_headers": ["x-request-id", "x-user-id"], + + "allowed_origins": ["https://app.yourcompany.com"], + "allow_direct_keys": false, + "enforce_auth_on_inference": true, + "max_request_body_size_mb": 100, + + "header_filter_config": { + "allowlist": [], + "denylist": [] + }, + "required_headers": [], + + "compat": { + "should_drop_params": false + }, + + "prometheus_labels": ["environment=production"], + + "mcp_agent_depth": 10, + "mcp_tool_execution_timeout": 30, + + "async_job_result_ttl": 3600 + } +} +``` diff --git a/docs/deployment-guides/config-json/governance.mdx b/docs/deployment-guides/config-json/governance.mdx new file mode 100644 index 0000000000..16ed48115e --- /dev/null +++ b/docs/deployment-guides/config-json/governance.mdx @@ -0,0 +1,333 @@ +--- +title: "Governance" +description: "Seed virtual keys, budgets, rate limits, routing rules, and admin auth in config.json" +icon: "shield-check" +--- + +The `governance` block lets you seed all governance resources directly in `config.json`. 
On startup, Bifrost loads these into the configuration store. This is the recommended approach for GitOps workflows where governance state is managed as code. + + +**Governance enforcement is always active** in OSS — you do not need a plugin entry to enable it. To require a virtual key on every inference request, set `client.enforce_auth_on_inference: true`. This is the global default, but a more specific inference-auth flag such as `governance.auth_config.disable_auth_on_inference` overrides it; if no specific override is set, `client.enforce_auth_on_inference` applies. + + +--- + +## Admin Authentication + +Protect the Bifrost dashboard and management API with username/password auth: + +```json +{ + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD", + "disable_auth_on_inference": false + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `is_enabled` | `false` | Enable admin username/password auth | +| `admin_username` | — | Admin username (supports `env.` prefix) | +| `admin_password` | — | Admin password (supports `env.` prefix) | +| `disable_auth_on_inference` | `false` | Skip auth check on `/v1/*` inference routes | + +--- + +## Virtual Keys + +Virtual keys are issued to clients and act as scoped API tokens. Each key specifies which providers, models, and API keys the bearer is allowed to use. 
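Conceptually, each provider config acts as an allow-list over models. A sketch of the check, assuming the documented wildcard semantics (`["*"]` permits every model, an empty list permits none):

```python
def model_allowed(allowed_models: list[str], model: str) -> bool:
    """Allow-list check: ["*"] permits all models, [] permits none."""
    return "*" in allowed_models or model in allowed_models

model_allowed(["*"], "gpt-4o")                  # wildcard: permitted
model_allowed(["gpt-4o", "gpt-4o-mini"], "o1")  # not listed: denied
model_allowed([], "gpt-4o")                     # empty list: denied
```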
+ +```json +{ + "governance": { + "virtual_keys": [ + { + "id": "vk-team-platform", + "name": "platform-team", + "value": "env.VK_PLATFORM_TEAM", + "is_active": true, + "provider_configs": [ + { + "provider": "openai", + "allowed_models": ["gpt-4o", "gpt-4o-mini"], + "key_ids": ["*"], + "weight": 1 + }, + { + "provider": "anthropic", + "allowed_models": ["*"], + "key_ids": ["*"], + "weight": 1 + } + ] + } + ] + } +} +``` + +### Virtual Key Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique virtual key ID (referenced by budgets / rate limits) | +| `name` | Yes | Human-readable name | +| `value` | No | The key token sent by clients (use `env.` prefix). Auto-generated if omitted | +| `is_active` | No | Default `true`. Set `false` to disable without deleting | +| `team_id` | No | Associate with a team (mutually exclusive with `customer_id`) | +| `customer_id` | No | Associate with a customer | +| `rate_limit_id` | No | Attach a rate limit | +| `calendar_aligned` | No | Snap budget resets to day/week/month/year boundaries | +| `provider_configs` | No | Allowed provider/model/key combinations (empty = deny all) | + +### Provider Config Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `provider` | Yes | Provider name (e.g. `"openai"`) | +| `allowed_models` | No | Model allow-list. `["*"]` = all models; `[]` = deny all | +| `key_ids` | No | Provider key names allowed for this VK. `["*"]` = all keys; `[]` = deny all. 
Use key `name` values (not UUIDs) in `config.json` | +| `weight` | No | Load-balancing weight when multiple provider configs are present | +| `rate_limit_id` | No | Attach a per-provider-config rate limit | + +--- + +## Budgets + +Budgets cap cumulative spend (in USD) for a virtual key or provider config over a rolling window: + +```json +{ + "governance": { + "budgets": [ + { + "id": "budget-platform-monthly", + "max_limit": 500.00, + "reset_duration": "1M", + "virtual_key_id": "vk-team-platform" + } + ] + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique budget ID | +| `max_limit` | Yes | Maximum spend in USD | +| `reset_duration` | Yes | Window length: `"30s"`, `"5m"`, `"1h"`, `"1d"`, `"1w"`, `"1M"`, `"1Y"` | +| `virtual_key_id` | No | Attach to a virtual key (mutually exclusive with `provider_config_id`) | +| `provider_config_id` | No | Attach to a provider config ID | + +--- + +## Rate Limits + +Rate limits cap requests or tokens over a rolling window: + +```json +{ + "governance": { + "rate_limits": [ + { + "id": "rl-platform-hourly", + "request_max_limit": 1000, + "request_reset_duration": "1h", + "token_max_limit": 1000000, + "token_reset_duration": "1h" + } + ] + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique rate limit ID | +| `request_max_limit` | No | Maximum requests in window | +| `request_reset_duration` | No | Window for request counter | +| `token_max_limit` | No | Maximum tokens (input + output) in window | +| `token_reset_duration` | No | Window for token counter | + +Attach a rate limit to a virtual key via `virtual_keys[].rate_limit_id`, or to a provider config via `virtual_keys[].provider_configs[].rate_limit_id`. + +--- + +## Routing Rules + +Routing rules dynamically select the provider and model for each request based on a [CEL](https://cel.dev) expression. 
They are evaluated in priority order before the request is dispatched. + +```json +{ + "governance": { + "routing_rules": [ + { + "id": "route-gpt4-to-azure", + "name": "Redirect GPT-4o to Azure", + "cel_expression": "request.model == 'gpt-4o'", + "targets": [ + { "provider": "azure", "model": "gpt-4o", "weight": 1.0 } + ] + }, + { + "id": "route-cost-split", + "name": "Split traffic 70/30 between providers", + "cel_expression": "true", + "targets": [ + { "provider": "openai", "weight": 0.7 }, + { "provider": "anthropic", "weight": 0.3 } + ] + } + ] + } +} +``` + +### Rule Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique rule ID | +| `name` | Yes | Human-readable name | +| `cel_expression` | No | CEL expression. `"true"` matches every request | +| `targets` | Yes | Weighted target list (weights must sum to `1.0`) | +| `enabled` | No | Default `true` | +| `priority` | No | Evaluation order within scope — lower numbers run first | +| `scope` | No | `"global"` (default), `"team"`, `"customer"`, `"virtual_key"` | +| `scope_id` | Conditional | Required when `scope` is not `"global"` | +| `chain_rule` | No | If `true`, re-evaluates the chain after this rule matches | +| `fallbacks` | No | Ordered fallback provider list if primary target fails | + +### Target Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `weight` | Yes | Fraction of traffic (all weights in a rule must sum to `1.0`) | +| `provider` | No | Target provider. Omit to keep the incoming request's provider | +| `model` | No | Target model. 
Omit to keep the incoming request's model | +| `key_id` | No | Pin a specific API key by name | + +--- + +## Customers & Teams + +Define organizational entities and attach budgets or rate limits to them: + +```json +{ + "governance": { + "customers": [ + { + "id": "customer-acme", + "name": "Acme Corp", + "budget_id": "budget-acme-monthly", + "rate_limit_id": "rl-acme-hourly" + } + ], + "teams": [ + { + "id": "team-ml", + "name": "ML Team", + "customer_id": "customer-acme", + "budget_id": "budget-team-ml" + } + ] + } +} +``` + +--- + +## Full Governance Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "client": { + "enforce_auth_on_inference": true + }, + + "governance": { + "auth_config": { + "is_enabled": true, + "admin_username": "env.BIFROST_ADMIN_USERNAME", + "admin_password": "env.BIFROST_ADMIN_PASSWORD" + }, + + "budgets": [ + { + "id": "budget-platform", + "max_limit": 1000.00, + "reset_duration": "1M", + "virtual_key_id": "vk-platform" + } + ], + + "rate_limits": [ + { + "id": "rl-platform", + "request_max_limit": 5000, + "request_reset_duration": "1h", + "token_max_limit": 5000000, + "token_reset_duration": "1h" + } + ], + + "virtual_keys": [ + { + "id": "vk-platform", + "name": "platform-key", + "value": "env.VK_PLATFORM", + "is_active": true, + "rate_limit_id": "rl-platform", + "provider_configs": [ + { + "provider": "openai", + "allowed_models": ["*"], + "key_ids": ["*"], + "weight": 1 + } + ] + } + ], + + "routing_rules": [ + { + "id": "fallback-to-anthropic", + "name": "Fallback on error", + "cel_expression": "true", + "targets": [{ "provider": "openai", "weight": 1.0 }], + "fallbacks": ["anthropic"] + } + ] + }, + + "providers": { + "openai": { + "keys": [{ "name": "openai-primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "anthropic": { + "keys": [{ "name": "anthropic-primary", "value": "env.ANTHROPIC_API_KEY", "models": ["*"], "weight": 1.0 
}] + } + }, + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost" + } + } +} +``` diff --git a/docs/deployment-guides/config-json/guardrails.mdx b/docs/deployment-guides/config-json/guardrails.mdx new file mode 100644 index 0000000000..f6258ca872 --- /dev/null +++ b/docs/deployment-guides/config-json/guardrails.mdx @@ -0,0 +1,291 @@ +--- +title: "Guardrails" +description: "Configure content moderation and policy enforcement in config.json using guardrails_config" +icon: "shield-halved" +--- + + +Guardrails are an **enterprise-only** feature and require the enterprise Bifrost image. + + +Guardrails are configured under `guardrails_config` in `config.json`. The configuration has two parts: + +- **`guardrail_providers`** — the backend that performs the check. Rules link to providers by `id`. +- **`guardrail_rules`** — CEL expressions that control when and where providers are invoked. + +--- + +## Providers + + + + +Runs entirely in-process with no external dependency. Patterns use RE2 syntax. Supports optional per-pattern flags: `i` (case-insensitive), `m` (multiline), `s` (dot-all). 
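Patterns can be sanity-checked offline before deploying. A sketch using Python's `re` module, which behaves identically to RE2 for straightforward credential matchers like these (the two patterns are taken from this page's examples):

```python
import re

# Example credential patterns (RE2-compatible; also valid Python re syntax).
patterns = {
    "OpenAI API key": r"sk-[A-Za-z0-9]{20,}",
    "AWS access key": r"AKIA[0-9A-Z]{16}",
}

def violations(text: str) -> list[str]:
    """Return the description of every pattern that matches the text."""
    return [desc for desc, pat in patterns.items() if re.search(pat, text)]

violations("my key is sk-ABCDEFGHIJKLMNOPQRSTuvwx")  # flags the OpenAI key
violations("nothing sensitive here")                  # no matches
```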
+ +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 1, + "provider_name": "regex", + "policy_name": "block-secrets", + "enabled": true, + "timeout": 5, + "config": { + "patterns": [ + { "pattern": "sk-[A-Za-z0-9]{20,}", "description": "OpenAI API key" }, + { "pattern": "AKIA[0-9A-Z]{16}", "description": "AWS access key" }, + { "pattern": "gh[ps]_[A-Za-z0-9]{36}", "description": "GitHub token", "flags": "i" } + ], + "mode": "block" + } + } + ] + } +} +``` + + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 2, + "provider_name": "bedrock", + "policy_name": "content-filter", + "enabled": true, + "timeout": 15, + "config": { + "guardrail_arn": "arn:aws:bedrock:us-east-1::guardrail/abc123", + "guardrail_version": "DRAFT", + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY" + } + } + ] + } +} +``` + + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 3, + "provider_name": "azure", + "policy_name": "azure-content-safety", + "enabled": true, + "timeout": 10, + "config": { + "endpoint": "https://your-resource.cognitiveservices.azure.com", + "api_key": "env.AZURE_CONTENT_SAFETY_KEY", + "analyze_enabled": true, + "analyze_severity_threshold": "medium", + "jailbreak_shield_enabled": true, + "indirect_attack_shield_enabled": true, + "copyright_enabled": false, + "text_blocklist_enabled": false, + "blocklist_names": [] + } + } + ] + } +} +``` + +`analyze_severity_threshold` accepts `"low"`, `"medium"`, or `"high"`. 
+ + + + +```json +{ + "guardrails_config": { + "guardrail_providers": [ + { + "id": 4, + "provider_name": "grayswan", + "policy_name": "grayswan-jailbreak", + "enabled": true, + "timeout": 15, + "config": { + "api_key": "env.GRAYSWAN_API_KEY", + "violation_threshold": 0.7, + "reasoning_mode": "standard", + "policy_id": "", + "policy_ids": [], + "rules": {} + } + } + ] + } +} +``` + + + + +### Provider Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique integer ID — referenced by rules via `provider_config_ids` | +| `provider_name` | Yes | Backend: `"regex"`, `"bedrock"`, `"azure"`, `"grayswan"` | +| `policy_name` | Yes | Human-readable policy label | +| `enabled` | Yes | `true` to activate | +| `timeout` | No | Execution timeout in seconds | +| `config` | No | Provider-specific configuration object | + +--- + +## Rules + +Rules are CEL expressions that fire when their condition matches. Available CEL variables: + +| Variable | Type | Description | +|----------|------|-------------| +| `model` | `string` | Model name from the request | +| `provider` | `string` | Provider name (e.g. 
`"openai"`) | +| `headers` | `map` | HTTP request headers | +| `params` | `map` | Query parameters | +| `customer` | `string` | Customer ID | +| `team` | `string` | Team ID | +| `user` | `string` | User ID | + +```json +{ + "guardrails_config": { + "guardrail_rules": [ + { + "id": 101, + "name": "block-secrets-input", + "description": "Block prompts containing credentials", + "enabled": true, + "cel_expression": "true", + "apply_to": "input", + "sampling_rate": 100, + "timeout": 10, + "provider_config_ids": [1] + }, + { + "id": 102, + "name": "content-safety-gpt4o-output", + "enabled": true, + "cel_expression": "model == 'gpt-4o'", + "apply_to": "output", + "sampling_rate": 100, + "timeout": 15, + "provider_config_ids": [3] + }, + { + "id": 103, + "name": "grayswan-openai-partial", + "enabled": true, + "cel_expression": "provider == 'openai'", + "apply_to": "input", + "sampling_rate": 50, + "timeout": 20, + "provider_config_ids": [4] + } + ] + } +} +``` + +### Rule Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique integer ID | +| `name` | Yes | Human-readable name | +| `description` | No | Optional description | +| `enabled` | Yes | `true` to activate | +| `cel_expression` | Yes | CEL boolean expression. `"true"` matches every request | +| `apply_to` | Yes | `"input"`, `"output"`, or `"both"` | +| `sampling_rate` | No | `0`–`100`; percentage of requests to evaluate (default: `100`) | +| `timeout` | No | Rule timeout in seconds | +| `provider_config_ids` | No | `id` values of providers to invoke when this rule matches. 
Multiple providers run in parallel | + +--- + +## Full Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "providers": { + "openai": { + "keys": [{ "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 }] + } + }, + + "guardrails_config": { + "guardrail_providers": [ + { + "id": 1, + "provider_name": "regex", + "policy_name": "block-secrets", + "enabled": true, + "timeout": 5, + "config": { + "patterns": [ + { "pattern": "sk-[A-Za-z0-9]{20,}", "description": "OpenAI API key" }, + { "pattern": "AKIA[0-9A-Z]{16}", "description": "AWS access key" } + ], + "mode": "block" + } + }, + { + "id": 2, + "provider_name": "azure", + "policy_name": "content-safety", + "enabled": true, + "timeout": 10, + "config": { + "endpoint": "https://your-resource.cognitiveservices.azure.com", + "api_key": "env.AZURE_CONTENT_SAFETY_KEY", + "analyze_enabled": true, + "analyze_severity_threshold": "medium", + "jailbreak_shield_enabled": true, + "indirect_attack_shield_enabled": false + } + } + ], + "guardrail_rules": [ + { + "id": 101, + "name": "block-secrets-input", + "description": "Block prompts leaking credentials", + "enabled": true, + "cel_expression": "true", + "apply_to": "input", + "sampling_rate": 100, + "timeout": 10, + "provider_config_ids": [1] + }, + { + "id": 102, + "name": "content-safety-both", + "description": "Azure content safety on all traffic", + "enabled": true, + "cel_expression": "true", + "apply_to": "both", + "sampling_rate": 100, + "timeout": 15, + "provider_config_ids": [2] + } + ] + } +} +``` diff --git a/docs/deployment-guides/config-json/plugins.mdx b/docs/deployment-guides/config-json/plugins.mdx new file mode 100644 index 0000000000..847f290e02 --- /dev/null +++ b/docs/deployment-guides/config-json/plugins.mdx @@ -0,0 +1,318 @@ +--- +title: "Plugins" +description: "Configure Bifrost plugins in config.json — semantic cache, OpenTelemetry, Maxim, Datadog, and 
custom plugins" +icon: "puzzle-piece" +--- + + +**The `plugins` array only controls explicitly opt-in plugins**: `semantic_cache`, `otel`, `maxim`, `datadog` (enterprise), and custom plugins. + +**Telemetry, logging, and governance are auto-loaded built-ins** — they are always active and configured via the `client` block and dedicated top-level keys, not the `plugins` array. + + +--- + +## Auto-Loaded Built-ins + +These plugins start automatically. You do **not** add them to the `plugins` array. + +| Plugin | Always active? | How to configure | +|--------|---------------|-----------------| +| **Telemetry** (Prometheus `/metrics`) | Yes, always | `client.prometheus_labels` for custom labels; push gateway via `plugins` entry once DB-backed mode is running | +| **Logging** | When `client.enable_logging: true` and `logs_store` is configured | `client.enable_logging`, `client.disable_content_logging`, `client.logging_headers` | +| **Governance** | Yes, always (OSS) | `client.enforce_auth_on_inference` for VK enforcement; `governance.*` for virtual keys / budgets / routing rules | + +See [Client Configuration](/deployment-guides/config-json/client) and [Governance](/deployment-guides/config-json/governance) for full details. + +--- + +## Plugin Array Structure + +Every entry in the `plugins` array supports these common fields: + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `name` | string | Yes | Plugin name | +| `enabled` | boolean | Yes | Enable or disable this plugin | +| `config` | object | Varies | Plugin-specific configuration | +| `path` | string | No | Path to a custom plugin binary or WASM file | +| `version` | integer | No | 🛑 **DB-Backed Only.** Plugin metadata persisted on `TablePlugin` rather than `PluginConfig`. Ignored in `config.json`. Used in UI/DB workflows to force refresh/reload. 
| +| `placement` | string | No | 🛑 **DB-Backed Only.** Execution metadata (`"pre_builtin"`, `"builtin"`, `"post_builtin"`) persisted on `TablePlugin`. Ignored in `config.json`. Relevant for dynamic plugin ordering in UI/DB mode. | +| `order` | integer | No | 🛑 **DB-Backed Only.** Execution metadata persisted on `TablePlugin`. Ignored in `config.json`. Within a placement group, lower values run earlier. | + + +`name`, `enabled`, `path`, and `config` are the core plugin config fields parsed from `config.json`. `version`, `placement`, and `order` are **not valid `config.json` keys**; they are DB-backed metadata persisted on `TablePlugin` and are only applicable when managing plugins dynamically via the UI or Database. + + +--- + + + + + +### Semantic Cache + +Caches LLM responses by semantic similarity. Returns a cached response when an incoming request is semantically close enough to a previous one. + +Requires a [vector store](/deployment-guides/config-json/storage#vector_store) to be configured. + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `config.dimension` | Yes | — | Embedding dimension. 
Use `1` for hash-based (exact) caching without an embedding provider | +| `config.provider` | No | — | Provider for generating embeddings (required for semantic mode) | +| `config.embedding_model` | No | — | Model for embeddings (required when `provider` is set) | +| `config.threshold` | No | `0.8` | Cosine similarity threshold for a cache hit (0.0–1.0) | +| `config.ttl` | No | `300` | Cache entry TTL in seconds (or a duration string like `"1h"`) | +| `config.cache_by_model` | No | `true` | Include model in cache key | +| `config.cache_by_provider` | No | `true` | Include provider in cache key | +| `config.exclude_system_prompt` | No | `false` | Exclude system prompt from cache key | +| `config.conversation_history_threshold` | No | `3` | Skip caching for requests with more messages than this | +| `config.default_cache_key` | No | — | Default cache key when no `x-bf-cache-key` header is sent | + +**Semantic mode** (embedding-based similarity search): + +```json +{ + "plugins": [ + { + "name": "semantic_cache", + "enabled": true, + "config": { + "provider": "openai", + "embedding_model": "text-embedding-3-small", + "dimension": 1536, + "threshold": 0.85, + "ttl": 300, + "cache_by_model": true, + "cache_by_provider": true + } + } + ] +} +``` + +**Hash mode** (exact-match caching, no embedding provider needed): + +```json +{ + "plugins": [ + { + "name": "semantic_cache", + "enabled": true, + "config": { + "dimension": 1, + "ttl": 1800 + } + } + ] +} +``` + + +You must also configure a `vector_store` in `config.json`. See [Storage — vector_store](/deployment-guides/config-json/storage#vector_store). + + + + + + +### OpenTelemetry (OTel) + +Exports distributed traces to any OTel-compatible collector (Jaeger, Zipkin, Tempo, Datadog via OTLP, etc.). 
+ +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `config.collector_url` | Yes | — | OTLP collector endpoint | +| `config.trace_type` | Yes | — | Trace format: `"genai_extension"`, `"vercel"`, or `"open_inference"` | +| `config.protocol` | Yes | — | `"http"` or `"grpc"` | +| `config.service_name` | No | `"bifrost"` | Service name reported to the collector | +| `config.metrics_enabled` | No | `false` | Enable push-based OTLP metrics export | +| `config.metrics_endpoint` | No | — | OTLP metrics endpoint URL | +| `config.metrics_push_interval` | No | `15` | Metrics push interval in seconds | +| `config.headers` | No | — | Custom headers for the collector (supports `env.` prefix) | +| `config.insecure` | No | `false` | Skip TLS verification | +| `config.tls_ca_cert` | No | — | Path to TLS CA certificate | + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "http://otel-collector:4318", + "trace_type": "genai_extension", + "protocol": "http", + "service_name": "bifrost-gateway" + } + } + ] +} +``` + +**With authentication headers:** + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "https://otel.example.com:4318", + "trace_type": "open_inference", + "protocol": "http", + "service_name": "bifrost", + "headers": { + "Authorization": "env.OTEL_AUTH_HEADER" + } + } + } + ] +} +``` + +**With OTLP metrics export:** + +```json +{ + "plugins": [ + { + "name": "otel", + "enabled": true, + "config": { + "collector_url": "http://otel-collector:4318", + "trace_type": "genai_extension", + "protocol": "http", + "metrics_enabled": true, + "metrics_endpoint": "http://otel-collector:4318/v1/metrics", + "metrics_push_interval": 30 + } + } + ] +} +``` + + + + + +### Maxim Observability + +Sends request traces to the [Maxim](https://www.getmaxim.ai) observability platform. 
+ +| Field | Required | Description | +|-------|----------|-------------| +| `config.api_key` | Yes | Maxim API key (use `env.` prefix) | +| `config.log_repo_id` | No | Default Maxim logger repository ID | + +```json +{ + "plugins": [ + { + "name": "maxim", + "enabled": true, + "config": { + "api_key": "env.MAXIM_API_KEY", + "log_repo_id": "your-log-repo-id" + } + } + ] +} +``` + + + + + +### Datadog + + +Datadog is an **enterprise-only** plugin and is silently ignored in OSS builds. + + +Sends APM traces and metrics to a Datadog Agent. + +| Field | Default | Description | +|-------|---------|-------------| +| `config.agent_addr` | `"localhost:8126"` | Datadog Agent address for APM traces | +| `config.service_name` | `"bifrost"` | Service name in Datadog | +| `config.env` | — | Environment tag (e.g. `"production"`, `"staging"`) | +| `config.version` | — | Service version tag | +| `config.enable_traces` | `true` | Enable APM trace collection | +| `config.custom_tags` | `{}` | Additional key/value tags for all traces and metrics | + +```json +{ + "plugins": [ + { + "name": "datadog", + "enabled": true, + "config": { + "agent_addr": "datadog-agent:8126", + "service_name": "bifrost", + "env": "production", + "enable_traces": true, + "custom_tags": { + "team": "platform", + "region": "us-east-1" + } + } + } + ] +} +``` + + + + + +--- + +## Custom / Dynamic Plugins + +Load a custom Go plugin binary or WASM plugin at startup using the `path` field. Custom plugins must implement one of the Bifrost plugin interfaces. 
+ +```json +{ + "plugins": [ + { + "name": "my-custom-auth", + "enabled": true, + "path": "/app/plugins/my-custom-auth.so", + "config": { + "auth_endpoint": "env.AUTH_SERVICE_URL" + } + } + ] +} +``` + +**WASM plugin:** + +```json +{ + "plugins": [ + { + "name": "my-wasm-plugin", + "enabled": true, + "path": "/app/plugins/my-plugin.wasm", + "config": {} + } + ] +} +``` + +See [Writing Go Plugins](/plugins/writing-go-plugin) and [Writing WASM Plugins](/plugins/writing-wasm-plugin) for implementation guides. + +**Placement and ordering (DB-backed only):** + +When creating plugins dynamically via the DB/UI (rather than `config.json`), you can specify their execution order: + +| `placement` | When it runs | +|-------------|-------------| +| `pre_builtin` | Before all built-in plugins | +| `builtin` | Alongside built-in plugins (by `order`) | +| `post_builtin` | After all built-in plugins (default) | + +Within a placement group, lower `order` values run earlier. diff --git a/docs/deployment-guides/config-json/providers.mdx b/docs/deployment-guides/config-json/providers.mdx new file mode 100644 index 0000000000..ca07e0e5f8 --- /dev/null +++ b/docs/deployment-guides/config-json/providers.mdx @@ -0,0 +1,755 @@ +--- +title: "Provider Setup" +description: "Configure LLM providers in config.json — API keys, cloud-native auth, per-provider network settings, and self-hosted endpoints" +icon: "plug" +--- + +All providers are configured under `providers` in `config.json`. Each provider entry contains a `keys` array where every key has a `name`, `value`, `models`, and `weight`, plus optional provider-specific config objects. 
+ +**Supplying credentials:** + +Use the `env.` prefix to reference environment variables — never put API keys directly in `config.json`: + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "name": "primary", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + } + } +} +``` + +--- + +## Common Provider Fields + +Every key object supports these fields: + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique name for this key (used in logs and virtual key pin) | +| `value` | string | API key value or `env.VAR_NAME` reference | +| `models` | array | Models this key serves. `["*"]` = all models | +| `weight` | float | Load balancing weight. Higher = more traffic | +| `aliases` | object | Map logical name → actual model name for this key | +| `use_for_batch_api` | boolean | Mark key as eligible for batch API calls | + +Per-provider `network_config` options (applies to all standard providers): + +| Field | Type | Description | +|-------|------|-------------| +| `default_request_timeout_in_seconds` | integer | Per-request timeout | +| `max_retries` | integer | Retry attempts on transient errors | +| `retry_backoff_initial` | integer | Initial backoff in milliseconds | +| `retry_backoff_max` | integer | Maximum backoff in milliseconds | +| `max_conns_per_host` | integer | Max TCP connections to the provider endpoint (default: 5000) | +| `extra_headers` | object | Static headers added to every provider request | +| `stream_idle_timeout_in_seconds` | integer | Idle timeout per stream chunk (default: 60) | +| `insecure_skip_verify` | boolean | Disable TLS verification (last resort only) | +| `ca_cert_pem` | string | PEM-encoded CA for self-signed or private CA endpoints | + +Concurrency and buffering per provider: + +| Field | Type | Description | +|-------|------|-------------| +| `concurrency_and_buffer_size.concurrency` | integer | Max concurrent requests to this provider | +| 
`concurrency_and_buffer_size.buffer_size` | integer | Request queue depth | + +--- + + + + + +### OpenAI + +Supports multiple keys with weighted load balancing. Mark one key with `use_for_batch_api: true` to designate it for the Batch API. + +```json +{ + "providers": { + "openai": { + "keys": [ + { + "name": "openai-primary", + "value": "env.OPENAI_KEY_1", + "models": ["*"], + "weight": 2.0 + }, + { + "name": "openai-secondary", + "value": "env.OPENAI_KEY_2", + "models": ["gpt-4o-mini"], + "weight": 1.0 + }, + { + "name": "openai-batch", + "value": "env.OPENAI_KEY_BATCH", + "models": ["*"], + "weight": 1.0, + "use_for_batch_api": true + } + ], + "network_config": { + "default_request_timeout_in_seconds": 120, + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } + } + } +} +``` + + + + + +### Anthropic + +```json +{ + "providers": { + "anthropic": { + "keys": [ + { + "name": "anthropic-primary", + "value": "env.ANTHROPIC_KEY_1", + "models": ["*"], + "weight": 1.0 + }, + { + "name": "anthropic-secondary", + "value": "env.ANTHROPIC_KEY_2", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "default_request_timeout_in_seconds": 180 + } + } + } +} +``` + +**Override Anthropic beta headers** (optional): + +```json +{ + "providers": { + "anthropic": { + "keys": [ + { + "name": "primary", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "beta_header_overrides": { + "redact-thinking-": true + } + } + } + } +} +``` + + + + + +### Azure OpenAI + +Azure requires `azure_key_config` on every key with `endpoint` and `api_version`. List your Azure deployment names in `models` — Bifrost routes requests using the model name as the deployment name. If your deployment names differ from the model names you use in requests, add an `aliases` map on the key. 
+ + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-primary", + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + +Set environment variables: + +```bash +export AZURE_API_KEY="your-azure-api-key" +export AZURE_ENDPOINT="https://your-resource.openai.azure.com" +export AZURE_API_VERSION="2024-10-21" +``` + + + + +When `value` is empty or omitted, Bifrost uses `DefaultAzureCredential` — which resolves credentials from Workload Identity, VM managed identity, or `az login`. + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-workload-identity", + "value": "", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + + + + +**Deployment name aliases** — when your Azure deployment names differ from the model names in requests, use `aliases`: + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-primary", + "value": "env.AZURE_API_KEY", + "models": ["gpt-4o"], + "weight": 1.0, + "aliases": { + "gpt-4o": "gpt-4o-prod-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + +**Multi-region failover** (two keys, different regions): + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "eastus", + "value": "env.AZURE_KEY_EAST", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT_EAST", + "api_version": "env.AZURE_API_VERSION" + } + }, + { + "name": "westus", + "value": "env.AZURE_KEY_WEST", + "models": ["gpt-4o"], + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT_WEST", + "api_version": "env.AZURE_API_VERSION" + } + } + ] + } + } +} +``` + + + + + 
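The deployment-name aliases above boil down to a per-key lookup table: the requested model name is checked against `aliases`, and unmatched names pass through unchanged. A minimal sketch of that resolution (illustrative only, not Bifrost's actual routing code):

```python
# Illustrative sketch of per-key alias resolution (not Bifrost's actual code).
# A requested model name is looked up in the key's aliases map; names without
# an alias entry pass through unchanged.

def resolve_deployment(requested_model: str, aliases: dict) -> str:
    """Return the Azure deployment name to route a request to."""
    return aliases.get(requested_model, requested_model)

# Mirrors the aliases example above.
aliases = {"gpt-4o": "gpt-4o-prod-deployment"}

print(resolve_deployment("gpt-4o", aliases))       # gpt-4o-prod-deployment
print(resolve_deployment("gpt-4o-mini", aliases))  # gpt-4o-mini
```

The same lookup pattern applies to the Bedrock inference-profile aliases and the Hugging Face / Replicate aliases shown later on this page.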
+### AWS Bedrock + +Bedrock requires `bedrock_key_config` with at minimum a `region`. Three auth modes: + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-static", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY" + } + } + ] + } + } +} +``` + + + + +When only `region` is set, Bifrost inherits credentials from the AWS SDK default chain — IRSA (IAM Roles for Service Accounts), EC2 instance profile, or `AWS_*` env vars. + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-iam", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-assumerole", + "value": "", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-west-2", + "role_arn": "env.AWS_ROLE_ARN", + "external_id": "env.AWS_EXTERNAL_ID", + "session_name": "bifrost-session" + } + } + ] + } + } +} +``` + + + + +**Model aliases** (map logical names to Bedrock inference profile IDs): + +```json +{ + "bedrock_key_config": { + "region": "us-east-1" + }, + "aliases": { + "claude-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "claude-haiku": "us.anthropic.claude-3-5-haiku-20241022-v1:0" + } +} +``` + +**Batch API — S3 configuration:** + +```json +{ + "bedrock_key_config": { + "region": "us-east-1", + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "batch_s3_config": { + "buckets": [ + { + "bucket_name": "my-bedrock-batch-bucket", + "prefix": "batch/", + "is_default": true + } + ] + } + } +} +``` + + + + + +### Google Vertex AI + +Vertex requires `vertex_key_config` with `project_id` and `region`. 
Two auth modes: + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-sa", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_AUTH_CREDENTIALS" + } + } + ] + } + } +} +``` + +`VERTEX_AUTH_CREDENTIALS` should contain the base64-encoded service account JSON. + + + + +When `auth_credentials` is omitted, Bifrost calls `google.FindDefaultCredentials` — which resolves to GKE Workload Identity, GCE metadata server, or `gcloud auth application-default login`. + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-workload-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "my-gcp-project", + "region": "us-central1" + } + } + ] + } + } +} +``` + + + + + + + + +### Standard API-Key Providers + +These providers follow the same simple pattern — one or more keys with weights. Replace the provider name and env var name accordingly. 
+ +```json +{ + "providers": { + "groq": { + "keys": [ + { + "name": "groq-primary", + "value": "env.GROQ_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "gemini": { + "keys": [ + { + "name": "gemini-primary", + "value": "env.GEMINI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "mistral": { + "keys": [ + { + "name": "mistral-primary", + "value": "env.MISTRAL_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "cohere": { + "keys": [{ "name": "cohere-main", "value": "env.COHERE_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "perplexity": { + "keys": [{ "name": "perplexity-main", "value": "env.PERPLEXITY_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "xai": { + "keys": [{ "name": "xai-main", "value": "env.XAI_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "cerebras": { + "keys": [{ "name": "cerebras-main", "value": "env.CEREBRAS_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "openrouter": { + "keys": [{ "name": "openrouter-main", "value": "env.OPENROUTER_API_KEY", "models": ["*"], "weight": 1.0 }] + }, + "nebius": { + "keys": [{ "name": "nebius-main", "value": "env.NEBIUS_API_KEY", "models": ["*"], "weight": 1.0 }] + } + } +} +``` + + + + + +### Self-Hosted Providers + +Self-hosted providers point to a URL you operate. No API key is typically required (`"value": ""`). 
+ + + + +```json +{ + "providers": { + "ollama": { + "keys": [ + { + "name": "ollama-local", + "value": "", + "models": ["*"], + "weight": 1.0, + "ollama_key_config": { + "url": "http://localhost:11434" + } + } + ] + } + } +} +``` + +Using an env var for the URL (useful across environments): + +```json +{ + "ollama_key_config": { + "url": "env.OLLAMA_URL" + } +} +``` + + + + +vLLM instances are model-specific — one key per served model: + +```json +{ + "providers": { + "vllm": { + "keys": [ + { + "name": "vllm-llama3-70b", + "value": "", + "models": ["llama-3-70b"], + "weight": 1.0, + "vllm_key_config": { + "url": "http://vllm-server:8000", + "model_name": "meta-llama/Meta-Llama-3-70B-Instruct" + } + }, + { + "name": "vllm-mistral", + "value": "", + "models": ["mistral-7b"], + "weight": 1.0, + "vllm_key_config": { + "url": "http://vllm-mistral:8000", + "model_name": "mistralai/Mistral-7B-Instruct-v0.3" + } + } + ] + } + } +} +``` + + + + +```json +{ + "providers": { + "sgl": { + "keys": [ + { + "name": "sgl-main", + "value": "", + "models": ["*"], + "weight": 1.0, + "sgl_key_config": { + "url": "http://sgl-router:30000" + } + } + ] + } + } +} +``` + + + + +These providers use `aliases` to map logical model names to provider-specific IDs: + +```json +{ + "providers": { + "huggingface": { + "keys": [ + { + "name": "hf-main", + "value": "env.HF_API_KEY", + "models": ["llama-3", "mixtral"], + "weight": 1.0, + "aliases": { + "llama-3": "meta-llama/Meta-Llama-3-8B-Instruct", + "mixtral": "mistralai/Mixtral-8x7B-Instruct-v0.1" + } + } + ] + }, + "replicate": { + "keys": [ + { + "name": "replicate-main", + "value": "env.REPLICATE_API_KEY", + "models": ["llama-3"], + "weight": 1.0, + "aliases": { + "llama-3": "meta/meta-llama-3-70b-instruct" + }, + "replicate_key_config": { + "use_deployments_endpoint": false + } + } + ] + } + } +} +``` + + + + + + + + +--- + +## Proxy Configuration + +Route provider traffic through an HTTP or SOCKS5 proxy: + +```json +{ + "providers": { + 
"openai": { + "keys": [ + { "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 } + ], + "proxy_config": { + "type": "http", + "url": "http://proxy.corp.example.com:3128", + "username": "env.PROXY_USER", + "password": "env.PROXY_PASS" + } + } + } +} +``` + +| Field | Type | Options | +|-------|------|---------| +| `proxy_config.type` | string | `"none"`, `"http"`, `"socks5"`, `"environment"` | +| `proxy_config.url` | string | Proxy server URL | +| `proxy_config.username` | string | Proxy auth username | +| `proxy_config.password` | string | Proxy auth password (`env.` supported) | +| `proxy_config.ca_cert_pem` | string | PEM CA for TLS-intercepting proxies | + +Use `"type": "environment"` to pick up `HTTP_PROXY` / `HTTPS_PROXY` env vars automatically. + +--- + +## Multi-Provider Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "providers": { + "openai": { + "keys": [ + { "name": "openai-primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 2.0 } + ] + }, + "anthropic": { + "keys": [ + { "name": "anthropic-primary", "value": "env.ANTHROPIC_API_KEY", "models": ["*"], "weight": 1.0 } + ] + }, + "groq": { + "keys": [ + { "name": "groq-primary", "value": "env.GROQ_API_KEY", "models": ["*"], "weight": 1.0 } + ] + } + } +} +``` + +With three providers and the weights above, traffic is distributed: 50% OpenAI, 25% Anthropic, 25% Groq. If any provider returns an error, Bifrost automatically retries on the next key or provider. 
diff --git a/docs/deployment-guides/config-json/schema-reference.mdx b/docs/deployment-guides/config-json/schema-reference.mdx new file mode 100644 index 0000000000..45b9b826ce --- /dev/null +++ b/docs/deployment-guides/config-json/schema-reference.mdx @@ -0,0 +1,202 @@ +--- +title: "Schema Reference" +description: "All top-level keys available in config.json, their types, and where each is documented" +icon: "brackets-curly" +--- + + +The live schema is published at [`https://www.getbifrost.ai/schema`](https://www.getbifrost.ai/schema). Add `"$schema": "https://www.getbifrost.ai/schema"` to your `config.json` for IDE autocomplete and inline validation. + + +This page is a concise reference for every top-level key in `config.json`. Click the **Guide** links for full field-by-field documentation. + +--- + +## Top-Level Keys + +| Key | Type | Description | Guide | +|-----|------|-------------|-------| +| `$schema` | string | Schema URL for IDE validation. Set to `"https://www.getbifrost.ai/schema"` | — | +| `encryption_key` | string | AES-256 key (derived via Argon2id). Accepts `env.VAR` prefix. 
Also read from `BIFROST_ENCRYPTION_KEY` env var | [Client](/deployment-guides/config-json/client#encryption-key) |
+| `version` | integer | Allow-list semantics for empty arrays (`2` default; `1` = v1.4.x compat) | [Version](#version) |
+| `client` | object | Worker pool, logging, CORS, auth enforcement, header filtering, MCP, compat shims | [Client](/deployment-guides/config-json/client) |
+| `providers` | object | LLM provider API keys, network settings, concurrency | [Providers](/deployment-guides/config-json/providers) |
+| `governance` | object | Admin auth, virtual keys, budgets, rate limits, routing rules, customers, teams | [Governance](/deployment-guides/config-json/governance) |
+| `guardrails_config` | object | Content moderation providers and CEL-based rules *(enterprise only)* | [Guardrails](/deployment-guides/config-json/guardrails) |
+| `config_store` | object | Configuration database backend — SQLite, PostgreSQL, or disabled (file-only mode) | [Storage](/deployment-guides/config-json/storage#config_store) |
+| `logs_store` | object | Request/response log database — SQLite, PostgreSQL + optional S3/GCS offload | [Storage](/deployment-guides/config-json/storage#logs_store) |
+| `vector_store` | object | Vector database for semantic cache — Weaviate, Redis, Qdrant, Pinecone, Valkey | [Storage](/deployment-guides/config-json/storage#vector_store) |
+| `plugins` | array | Opt-in plugins: `semantic_cache`, `otel`, `maxim`, `datadog`, custom | [Plugins](/deployment-guides/config-json/plugins) |
+| `framework` | object | Model pricing catalog URL and sync interval | [Framework](#framework) |
+| `mcp` | object | MCP server and tool configuration | — |
+| `websocket` | object | WebSocket / Realtime API connection pool tuning | [WebSocket](#websocket) |
+| `auth_config` | object | **Deprecated** — use `governance.auth_config` | [Client](/deployment-guides/config-json/client#authentication) |
+
+---
+
+## `version`
+
+Controls how empty arrays in allow-list fields (`models`, `allowed_models`, `key_ids`, `tools_to_execute`) are interpreted:
+
+| Value | Behaviour |
+|-------|-----------| +| `2` *(default, v1.5.0+)* | Empty array = **deny all**; `["*"]` = allow all | +| `1` *(v1.4.x compat)* | Empty array = **allow all** | + +Omitting `version` uses v2 semantics. Set `"version": 1` only if you are migrating from v1.4.x and need the old behaviour temporarily. + +--- + +## `client` + +Controls the worker pool, logging pipeline, security, and SDK shims. All fields are optional. + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `initial_pool_size` | integer | `300` | Pre-allocated goroutines per provider queue | +| `drop_excess_requests` | boolean | `false` | Return HTTP 429 when queue is full | +| `enable_logging` | boolean | `true`* | Persist request/response logs (`*` auto-enabled when `logs_store` is set) | +| `disable_content_logging` | boolean | `false` | Strip message content from logs | +| `log_retention_days` | integer | `365` | Days to retain log entries | +| `logging_headers` | array | `[]` | HTTP headers to capture in log metadata | +| `enforce_auth_on_inference` | boolean | `false` | Require a virtual key on every `/v1/*` request | +| `allow_direct_keys` | boolean | `false` | Allow callers to pass provider API keys directly | +| `allowed_origins` | array | `["*"]` | CORS allowed origins | +| `max_request_body_size_mb` | integer | `100` | Maximum request body in MB | +| `whitelisted_routes` | array | `[]` | Routes that bypass auth middleware | +| `allowed_headers` | array | `[]` | Additional headers permitted for CORS/WebSocket | +| `required_headers` | array | `[]` | Headers that must be present on every request | +| `header_filter_config` | object | — | `allowlist` / `denylist` for `x-bf-eh-*` forwarded headers | +| `prometheus_labels` | array | `[]` | Custom labels for all Prometheus metrics | +| `compat` | object | — | SDK compatibility shims (`should_drop_params`, `convert_text_to_chat`, etc.) 
| +| `mcp_agent_depth` | integer | `10` | Max tool-call recursion depth | +| `mcp_tool_execution_timeout` | integer | `30` | Per-tool execution timeout in seconds | +| `mcp_tool_sync_interval` | integer | `10` | Tool sync interval in minutes (`0` = disabled) | +| `mcp_disable_auto_tool_inject` | boolean | `false` | Disable automatic MCP tool injection | +| `async_job_result_ttl` | integer | `3600` | TTL for async job results in seconds | +| `disable_db_pings_in_health` | boolean | `false` | Exclude DB connectivity from `/health` | +| `routing_chain_max_depth` | integer | `10` | Max routing rule chain evaluation depth | + +Full documentation: [Client Configuration](/deployment-guides/config-json/client). + +--- + +## `providers` + +Keyed by provider name. Each entry contains a `keys` array and optional `network_config`, `concurrency_and_buffer_size`, `proxy_config`. + +Supported provider keys: `openai`, `anthropic`, `azure`, `bedrock`, `vertex`, `gemini`, `mistral`, `groq`, `cohere`, `perplexity`, `xai`, `cerebras`, `openrouter`, `nebius`, `fireworks`, `parasail`, `huggingface`, `replicate`, `ollama`, `vllm`, `sgl`, `elevenlabs`, `runway`. + +Full documentation: [Provider Setup](/deployment-guides/config-json/providers). + +--- + +## `governance` + +Seeds governance resources at startup. All sub-keys are optional arrays. 
+ +| Sub-key | Description | +|---------|-------------| +| `auth_config` | Admin username/password auth for the dashboard | +| `virtual_keys` | Scoped API tokens with provider/model allowlists | +| `budgets` | Spend caps in USD over a rolling window | +| `rate_limits` | Request and token rate limits | +| `customers` | Customer entities (attach budgets/rate limits) | +| `teams` | Team entities (attach to customers, budgets, rate limits) | +| `routing_rules` | CEL-based dynamic provider/model routing | +| `pricing_overrides` | Scoped per-model pricing overrides | +| `model_configs` | Per-model rate limit and budget configurations | + +Full documentation: [Governance](/deployment-guides/config-json/governance). + +--- + +## `guardrails_config` + +Enterprise-only. Two sub-keys: `guardrail_providers` (array) and `guardrail_rules` (array). + +Full documentation: [Guardrails](/deployment-guides/config-json/guardrails). + +--- + +## `config_store`, `logs_store`, `vector_store` + +Storage backends. Each has `enabled` (boolean), `type` (string), and `config` (object). + +| Store | Types | +|-------|-------| +| `config_store` | `"sqlite"`, `"postgres"` | +| `logs_store` | `"sqlite"`, `"postgres"` (+ optional `object_storage`) | +| `vector_store` | `"weaviate"`, `"redis"`, `"qdrant"`, `"pinecone"` (`"redis"` also covers Valkey-compatible endpoints) | + +Full documentation: [Storage](/deployment-guides/config-json/storage). 
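As a quick pre-deploy check, the key table on this page can double as a lint step that flags unrecognised top-level keys. A rough sketch (the published JSON Schema at `https://www.getbifrost.ai/schema` is the authoritative validator):

```python
# Rough sketch: flag unrecognised top-level keys before deploying.
# The published JSON Schema is the authoritative validator; this set only
# mirrors the keys documented on this page.
import json

KNOWN_TOP_LEVEL_KEYS = {
    "$schema", "version", "encryption_key", "client", "providers",
    "governance", "guardrails_config", "config_store", "logs_store",
    "vector_store", "plugins", "framework", "mcp", "websocket", "auth_config",
}

def unknown_top_level_keys(config_text: str) -> set:
    """Return any top-level keys not documented on this page."""
    return set(json.loads(config_text)) - KNOWN_TOP_LEVEL_KEYS

cfg = '{"providers": {}, "provders": {}}'  # note the typo
print(unknown_top_level_keys(cfg))  # {'provders'}
```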
+ +--- + +## `framework` + +Controls model pricing catalog sync: + +```json +{ + "framework": { + "pricing": { + "pricing_url": "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json", + "pricing_sync_interval": 86400 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `pricing.pricing_url` | LiteLLM catalog | URL of a model pricing JSON file | +| `pricing.pricing_sync_interval` | `86400` | Sync interval in seconds (minimum: `3600`) | + +--- + +## `websocket` + +Optional tuning for the WebSocket gateway (Responses API WebSocket mode, Realtime API). WebSocket is always enabled. + +```json +{ + "websocket": { + "max_connections_per_user": 100, + "transcript_buffer_size": 100, + "pool": { + "max_idle_per_key": 50, + "max_total_connections": 1000, + "idle_timeout_seconds": 600, + "max_connection_lifetime_seconds": 7200 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `max_connections_per_user` | `100` | Max concurrent WebSocket connections per user | +| `transcript_buffer_size` | `100` | Transcript entries buffered for Realtime API mid-session fallback | +| `pool.max_idle_per_key` | `50` | Max idle upstream connections per provider/key | +| `pool.max_total_connections` | `1000` | Max total idle upstream connections | +| `pool.idle_timeout_seconds` | `600` | Evict idle connections after this many seconds | +| `pool.max_connection_lifetime_seconds` | `7200` | Max lifetime of any upstream connection | + +--- + +## Minimal Valid Config + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + "providers": { + "openai": { + "keys": [ + { "name": "primary", "value": "env.OPENAI_API_KEY", "models": ["*"], "weight": 1.0 } + ] + } + }, + "config_store": { "enabled": false } +} +``` diff --git a/docs/deployment-guides/config-json/storage.mdx b/docs/deployment-guides/config-json/storage.mdx new file mode 
100644 index 0000000000..fd4bdbde97 --- /dev/null +++ b/docs/deployment-guides/config-json/storage.mdx @@ -0,0 +1,540 @@ +--- +title: "Storage" +description: "Configure Bifrost storage backends in config.json — config_store, logs_store, vector_store, and object storage for logs" +icon: "database" +--- + +Bifrost persists two types of data — **config** (providers, virtual keys, governance rules) and **logs** (request/response records). Each has its own store. A **vector store** is required for semantic caching. + +| Store | Purpose | Backends | +|-------|---------|---------| +| `config_store` | Provider configs, virtual keys, governance rules | SQLite, PostgreSQL | +| `logs_store` | Request/response logs shown in UI | SQLite, PostgreSQL + optional S3/GCS offload | +| `vector_store` | Semantic response caching | Weaviate, Redis, Valkey, Qdrant, Pinecone | + + +If you use PostgreSQL for any store, the target database must be **UTF8 encoded**. See [PostgreSQL UTF8 Requirement](/quickstart/gateway/setting-up#postgresql-utf8-requirement). + + +--- + +## config_store + + +When `config_store` is disabled (or absent), all configuration is loaded from `config.json` at startup only — the Web UI is disabled and changes require a restart. See [Two Configuration Modes](/deployment-guides/config-json#two-configuration-modes). + + + + + + +### SQLite (Default) + +Simplest setup — no external database required. Bifrost stores configuration in a local SQLite file. + +```json +{ + "config_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./config.db" + } + } +} +``` + +| Field | Description | +|-------|-------------| +| `config.path` | Path to the SQLite file (relative to app-dir, or absolute) | + + + + + +### PostgreSQL + +Production-grade storage suitable for high-availability and high-throughput deployments. 
+ +```json +{ + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 5, + "max_open_conns": 50 + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `host` | — | PostgreSQL host (supports `env.` prefix) | +| `port` | — | PostgreSQL port (as string) | +| `user` | — | Database user (supports `env.` prefix) | +| `password` | — | Database password (supports `env.` prefix). Leave empty for IAM role auth. | +| `db_name` | — | Database name | +| `ssl_mode` | — | `"disable"`, `"require"`, `"verify-ca"`, `"verify-full"` | +| `max_idle_conns` | `5` | Maximum idle connections in the pool | +| `max_open_conns` | `50` | Maximum open connections to the database | + + + + + +### Disabled (file-only mode) + +Use this when you want Bifrost to read all configuration from `config.json` only — no database, no Web UI. + +```json +{ + "config_store": { + "enabled": false + } +} +``` + +This is the recommended setup for [multinode OSS deployments](/deployment-guides/how-to/multinode) where a shared `config.json` is the single source of truth. 
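The `env.`-prefixed values in the examples above are resolved against environment variables when the config is loaded. A minimal sketch of the convention (illustrative, not Bifrost's implementation):

```python
# Minimal sketch of the `env.` prefix convention (not Bifrost's implementation).
# Values that start with "env." are replaced by the named environment
# variable; anything else is used literally.
import os

def resolve_value(value: str) -> str:
    if value.startswith("env."):
        var = value[len("env."):]
        if var not in os.environ:
            raise KeyError(f"environment variable {var!r} is not set")
        return os.environ[var]
    return value

os.environ["PG_HOST"] = "db.internal"  # stand-in value for the sketch
print(resolve_value("env.PG_HOST"))  # db.internal
print(resolve_value("bifrost"))      # bifrost
```

Because resolution happens at load time, a missing variable fails fast at startup rather than surfacing later as a bad connection string.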
+ + + + + +--- + +## logs_store + + + + + +### SQLite + +```json +{ + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + + + + + +### PostgreSQL + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 100 + } + } +} +``` + +For high log volumes, increase `max_open_conns`: + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 200 + }, + "retention_days": 90 + } +} +``` + + + + + +```json +{ + "logs_store": { + "enabled": false + } +} +``` + + + + + +### Log Retention + +Set `retention_days` to automatically purge old log entries. `0` disables retention-based cleanup. + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." }, + "retention_days": 90 + } +} +``` + +### Object Storage for Logs + +Offload large request/response payloads from the database to S3 or GCS. The database retains only lightweight index records; payloads are fetched on demand. + + + + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." 
}, + "object_storage": { + "type": "s3", + "bucket": "env.S3_BUCKET", + "prefix": "bifrost", + "compress": true, + "region": "us-east-1", + "access_key_id": "env.S3_ACCESS_KEY_ID", + "secret_access_key": "env.S3_SECRET_ACCESS_KEY" + } + } +} +``` + +**IAM role (instance profile / IRSA)** — omit `access_key_id` and `secret_access_key`: + +```json +{ + "object_storage": { + "type": "s3", + "bucket": "bifrost-logs", + "region": "us-east-1", + "compress": true, + "role_arn": "arn:aws:iam::123456789012:role/BifrostS3Role" + } +} +``` + +| Field | Description | +|-------|-------------| +| `bucket` | S3 bucket name (supports `env.` prefix) | +| `prefix` | Key prefix for stored objects (default: `"bifrost"`) | +| `compress` | Enable gzip compression (default: `false`) | +| `region` | AWS region | +| `access_key_id` | AWS access key ID (omit for default credential chain) | +| `secret_access_key` | AWS secret access key | +| `session_token` | STS temporary credentials session token | +| `role_arn` | IAM role ARN for STS AssumeRole | +| `endpoint` | Custom endpoint for MinIO / Cloudflare R2 | +| `force_path_style` | Use path-style URLs (required for MinIO, default: `false`) | + + + + +```json +{ + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { "...": "..." }, + "object_storage": { + "type": "gcs", + "bucket": "bifrost-logs", + "prefix": "bifrost", + "compress": true, + "project_id": "env.GCP_PROJECT_ID", + "credentials_json": "env.GCS_CREDENTIALS_JSON" + } + } +} +``` + +Omit `credentials_json` to use Application Default Credentials (Workload Identity, GCE metadata, `gcloud auth`). 
+ +| Field | Description | +|-------|-------------| +| `project_id` | GCP project ID (supports `env.` prefix) | +| `credentials_json` | Service account JSON or path — omit for ADC | + + + + +```json +{ + "object_storage": { + "type": "s3", + "bucket": "bifrost-logs", + "prefix": "bifrost", + "compress": false, + "region": "us-east-1", + "endpoint": "http://minio.internal:9000", + "access_key_id": "env.MINIO_ACCESS_KEY", + "secret_access_key": "env.MINIO_SECRET_KEY", + "force_path_style": true + } +} +``` + + + + +--- + +## vector_store + +A vector store is required for [semantic caching](/features/semantic-caching). Choose from Weaviate, Redis/Valkey, Qdrant, or Pinecone. + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "localhost:8080", + "api_key": "env.WEAVIATE_API_KEY", + "grpc_config": { + "host": "localhost:50051", + "secured": false + } + } + } +} +``` + +| Field | Required | Description | +|-------|----------|-------------| +| `scheme` | Yes | `"http"` or `"https"` | +| `host` | Yes | Weaviate server host and port | +| `api_key` | No | Weaviate API key (supports `env.` prefix) | +| `grpc_config.host` | No | gRPC host for faster vector operations | +| `grpc_config.secured` | No | Use TLS for gRPC connection | + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "env.REDIS_ADDR", + "password": "env.REDIS_PASSWORD", + "db": 0, + "use_tls": false + } + } +} +``` + +**AWS MemoryDB (cluster mode):** + +```json +{ + "vector_store": { + "enabled": true, + "type": "redis", + "config": { + "addr": "env.MEMORYDB_ENDPOINT", + "password": "env.MEMORYDB_PASSWORD", + "use_tls": true, + "cluster_mode": true + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `addr` | — | Redis/Valkey address `host:port` (supports `env.` prefix) | +| `password` | — | Redis AUTH password (supports `env.` prefix) | +| `db` | 
`0` | Redis database number | +| `use_tls` | `false` | Enable TLS | +| `cluster_mode` | `false` | Enable cluster mode (required for MemoryDB; `db` must be `0`) | +| `pool_size` | — | Maximum socket connections | + + + + + +```json +{ + "vector_store": { + "enabled": true, + "type": "qdrant", + "config": { + "host": "env.QDRANT_HOST", + "port": 6334, + "api_key": "env.QDRANT_API_KEY", + "use_tls": false + } + } +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `host` | — | Qdrant server host (supports `env.` prefix) | +| `port` | `6334` | gRPC port | +| `api_key` | — | API key (supports `env.` prefix) | +| `use_tls` | `false` | Enable TLS | + + + + + +Pinecone is external-only. + +```json +{ + "vector_store": { + "enabled": true, + "type": "pinecone", + "config": { + "api_key": "env.PINECONE_API_KEY", + "index_host": "env.PINECONE_INDEX_HOST" + } + } +} +``` + +| Field | Description | +|-------|-------------| +| `api_key` | Pinecone API key (supports `env.` prefix) | +| `index_host` | Index host from Pinecone console (e.g. 
`your-index.svc.us-east1-gcp.pinecone.io`) | + + + + + +--- + +## Mixed Backend Example + +Run the config store on PostgreSQL (for UI) while keeping logs on SQLite (simpler, cheaper for append-heavy workloads): + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require" + } + }, + + "logs_store": { + "enabled": true, + "type": "sqlite", + "config": { + "path": "./logs.db" + } + } +} +``` + +--- + +## Full Storage Example + +```json +{ + "$schema": "https://www.getbifrost.ai/schema", + "encryption_key": "env.BIFROST_ENCRYPTION_KEY", + + "config_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 5, + "max_open_conns": 50 + } + }, + + "logs_store": { + "enabled": true, + "type": "postgres", + "config": { + "host": "env.PG_HOST", + "port": "5432", + "user": "env.PG_USER", + "password": "env.PG_PASSWORD", + "db_name": "bifrost", + "ssl_mode": "require", + "max_idle_conns": 10, + "max_open_conns": 100 + }, + "retention_days": 90, + "object_storage": { + "type": "s3", + "bucket": "env.S3_BUCKET", + "region": "us-east-1", + "compress": true, + "access_key_id": "env.S3_ACCESS_KEY_ID", + "secret_access_key": "env.S3_SECRET_ACCESS_KEY" + } + }, + + "vector_store": { + "enabled": true, + "type": "weaviate", + "config": { + "scheme": "http", + "host": "weaviate:8080" + } + } +} +``` diff --git a/docs/deployment-guides/helm.mdx b/docs/deployment-guides/helm.mdx index cc09723b8a..3be4c5aa5c 100644 --- a/docs/deployment-guides/helm.mdx +++ b/docs/deployment-guides/helm.mdx @@ -1,740 +1,103 @@ --- -title: "Helm" -description: 
"Deploy Bifrost on Kubernetes using Helm charts with flexible configuration options" -icon: "helicopter-symbol" +title: "Quick Start" +description: "Deploy Bifrost on Kubernetes using the official Helm chart — quickstart for OSS and Enterprise" +icon: "server" --- -Deploy Bifrost on Kubernetes using the official Helm chart. This is the recommended way to deploy Bifrost on Kubernetes with production-ready defaults and flexible configuration. - -**Latest Chart Version:** 1.5.0 | [View on Artifact Hub](https://artifacthub.io/packages/helm/bifrost/bifrost) +**Latest Chart Version:** 2.1.0 | [View on Artifact Hub](https://artifacthub.io/packages/helm/bifrost/bifrost) + + + + ## Prerequisites - Kubernetes cluster (v1.19+) - `kubectl` configured - Helm 3.2.0+ installed -- (Optional) Persistent Volume provisioner -- (Optional) Ingress controller +- Persistent Volume provisioner (required for SQLite; optional for Postgres-only) If you use PostgreSQL for Bifrost storage, ensure the database is UTF8 encoded. See [PostgreSQL UTF8 Requirement](../quickstart/gateway/setting-up#postgresql-utf8-requirement). -## Quick Start - -### Add Helm Repository +## Step 1 — Add the Helm Repository ```bash helm repo add bifrost https://maximhq.github.io/bifrost/helm-charts helm repo update ``` -### Install Bifrost - -```bash -helm install bifrost bifrost/bifrost --set image.tag=1.3.45 -``` +## Step 2 — Install -The `image.tag` parameter is required. Check [Docker Hub](https://hub.docker.com/r/maximhq/bifrost/tags) for available versions. +The Helm chart ships ready-made values files under `helm-charts/bifrost/values-examples/`. +For example: `sqlite-only.yaml`, `production-ha.yaml`, `external-postgres.yaml`, and `secrets-from-k8s.yaml`. 
+See the full list here: https://github.com/maximhq/bifrost/tree/main/helm-charts/bifrost/values-examples -This deploys Bifrost with: -- SQLite storage (10Gi PVC) -- Single replica -- ClusterIP service - -### Access Bifrost - -```bash -kubectl port-forward svc/bifrost 8080:8080 -curl http://localhost:8080/metrics -``` - -## Deployment Patterns - - - - -### Development Setup - -Simple setup for local testing and development. - -```bash -helm install bifrost bifrost/bifrost \ - --set image.tag=1.3.45 \ - --set bifrost.providers.openai.keys[0].value="sk-your-key" \ - --set bifrost.providers.openai.keys[0].weight=1 -``` - -**Features:** -- SQLite storage -- Single replica -- No auto-scaling -- ClusterIP service - -**Access:** -```bash -kubectl port-forward svc/bifrost 8080:8080 -``` - - - - - -### Production Setup - -High-availability setup with PostgreSQL and auto-scaling. - -```yaml -# production.yaml -image: - tag: "1.3.45" # Required: specify the Bifrost version - -replicaCount: 3 - -storage: - mode: postgres - -postgresql: - enabled: true - auth: - password: "your-secure-password" - primary: - persistence: - size: 50Gi - resources: - requests: - cpu: 500m - memory: 1Gi - limits: - cpu: 2000m - memory: 2Gi - -autoscaling: - enabled: true - minReplicas: 3 - maxReplicas: 10 - targetCPUUtilizationPercentage: 70 - targetMemoryUtilizationPercentage: 80 - -ingress: - enabled: true - className: nginx - annotations: - cert-manager.io/cluster-issuer: letsencrypt-prod - hosts: - - host: bifrost.yourdomain.com - paths: - - path: / - pathType: Prefix - tls: - - secretName: bifrost-tls - hosts: - - bifrost.yourdomain.com - -resources: - requests: - cpu: 500m - memory: 1Gi - limits: - cpu: 2000m - memory: 2Gi - -bifrost: - encryptionKey: "your-32-byte-encryption-key" - logLevel: info - - client: - dropExcessRequests: true - enableLogging: true - - providers: - openai: - keys: - - value: "sk-..." 
- weight: 1 - - plugins: - telemetry: - enabled: true - logging: - enabled: true - governance: - enabled: true -``` - -**Install:** -```bash -helm install bifrost bifrost/bifrost -f production.yaml -``` - -**Features:** -- 3 initial replicas (scales 3-10) -- PostgreSQL database -- Ingress with TLS -- Monitoring enabled - - - - - -### AI Workloads with Semantic Caching - -Optimized for high-volume AI inference with caching. - -```yaml -# ai-workload.yaml -image: - tag: "1.3.45" # Required: specify the Bifrost version - -storage: - mode: postgres - -postgresql: - enabled: true - auth: - password: "secure-password" - primary: - persistence: - size: 50Gi - -vectorStore: - enabled: true - type: weaviate - weaviate: - enabled: true - persistence: - size: 50Gi - resources: - requests: - cpu: 500m - memory: 1Gi - limits: - cpu: 2000m - memory: 2Gi - -bifrost: - encryptionKey: "your-encryption-key" - - providers: - openai: - keys: - - value: "sk-..." - weight: 1 - - plugins: - semanticCache: - enabled: true - config: - provider: "openai" - embedding_model: "text-embedding-3-small" - dimension: 1536 - threshold: 0.8 - ttl: "5m" - cache_by_model: true - cache_by_provider: true -``` - -**Install:** -```bash -helm install bifrost bifrost/bifrost -f ai-workload.yaml -``` - -**Features:** -- PostgreSQL for config/logs -- Weaviate for vector storage -- Semantic caching enabled -- Optimized for AI workloads - - - - - -### Multi-Provider Setup - -Support multiple LLM providers with load balancing. - -```yaml -# multi-provider.yaml -image: - tag: "1.3.45" # Required: specify the Bifrost version - -bifrost: - encryptionKey: "your-encryption-key" - - client: - enableLogging: true - allowDirectKeys: false - - providers: - openai: - keys: - - value: "sk-..." - weight: 2 - anthropic: - keys: - - value: "sk-ant-..." - weight: 1 - gemini: - keys: - - value: "..." - weight: 1 - cohere: - keys: - - value: "..." 
- weight: 1 - - plugins: - telemetry: - enabled: true - logging: - enabled: true -``` - -**Install:** -```bash -helm install bifrost bifrost/bifrost -f multi-provider.yaml -``` - -**Features:** -- Multiple provider support -- Weighted load balancing -- Request/response logging -- Telemetry enabled - - - - - -### External Database - -Use existing PostgreSQL instance. - -```yaml -# external-db.yaml -image: - tag: "1.3.45" # Required: specify the Bifrost version - -storage: - mode: postgres - -postgresql: - enabled: false - external: - enabled: true - host: "postgres.example.com" - port: 5432 - user: "bifrost" - password: "your-password" - database: "bifrost" - sslMode: "require" - -bifrost: - encryptionKey: "your-encryption-key" - - providers: - openai: - keys: - - value: "sk-..." - weight: 1 -``` - -**Install:** -```bash -helm install bifrost bifrost/bifrost -f external-db.yaml -``` - -**Features:** -- Uses external PostgreSQL -- No embedded database -- SSL connection support - - - - - -### Using Kubernetes Secrets - -Store all sensitive values in Kubernetes secrets instead of values files. - -**Prerequisites:** Create Kubernetes secrets first: - -```bash -# PostgreSQL password -kubectl create secret generic postgres-credentials \ - --from-literal=password='your-postgres-password' - -# Encryption key -kubectl create secret generic bifrost-encryption \ - --from-literal=key='your-encryption-key' - -# Provider API keys -kubectl create secret generic provider-api-keys \ - --from-literal=openai-api-key='sk-...' \ - --from-literal=anthropic-api-key='sk-ant-...' 
- -# Qdrant API key (if using) -kubectl create secret generic qdrant-credentials \ - --from-literal=api-key='your-qdrant-api-key' -``` - -```yaml -# secrets-config.yaml -image: - tag: "1.3.45" - -storage: - mode: postgres - -# External PostgreSQL with secret reference -postgresql: - enabled: false - external: - enabled: true - host: "postgres.example.com" - port: 5432 - user: "bifrost" - database: "bifrost" - sslMode: "require" - existingSecret: "postgres-credentials" - passwordKey: "password" - -# Vector store with secret reference -vectorStore: - enabled: true - type: qdrant - qdrant: - external: - enabled: true - host: "qdrant.example.com" - port: 6334 - existingSecret: "qdrant-credentials" - apiKeyKey: "api-key" - -bifrost: - # Encryption key from secret - encryptionKeySecret: - name: "bifrost-encryption" - key: "key" - - # Provider configs using env var references - providers: - openai: - keys: - - value: "env.OPENAI_API_KEY" - weight: 1 - anthropic: - keys: - - value: "env.ANTHROPIC_API_KEY" - weight: 1 - - # Inject provider secrets as env vars - providerSecrets: - openai: - existingSecret: "provider-api-keys" - key: "openai-api-key" - envVar: "OPENAI_API_KEY" - anthropic: - existingSecret: "provider-api-keys" - key: "anthropic-api-key" - envVar: "ANTHROPIC_API_KEY" -``` - -**Install:** -```bash -helm install bifrost bifrost/bifrost -f secrets-config.yaml -``` - -**Features:** -- No sensitive values in values files -- Secrets managed by Kubernetes -- Works with external secret managers (Vault, AWS Secrets Manager via External Secrets Operator) - - - - -## Configuration - -### Key Parameters - -| Parameter | Description | Default | -|-----------|-------------|---------| -| `image.tag` | **Required.** Bifrost image version (e.g., 1.3.45) | `""` | -| `replicaCount` | Number of replicas | `1` | -| `storage.mode` | Storage backend (sqlite/postgres) | `sqlite` | -| `storage.persistence.size` | PVC size for SQLite | `10Gi` | -| `postgresql.enabled` | Deploy 
PostgreSQL | `false` | -| `vectorStore.enabled` | Enable vector store | `false` | -| `vectorStore.type` | Vector store type (weaviate/redis/qdrant). Use `redis` for Redis or Valkey-compatible services | `none` | -| `bifrost.encryptionKey` | Encryption key | `""` | -| `ingress.enabled` | Enable ingress | `false` | -| `autoscaling.enabled` | Enable HPA | `false` | - -### Secret Reference Parameters - -Use existing Kubernetes secrets instead of plain-text values: - -| Parameter | Description | Default | -|-----------|-------------|---------| -| `bifrost.encryptionKeySecret.name` | Secret name for encryption key | `""` | -| `bifrost.encryptionKeySecret.key` | Key within the secret | `""` | -| `postgresql.external.existingSecret` | Secret name for PostgreSQL password | `""` | -| `postgresql.external.passwordKey` | Key within the secret | `"password"` | -| `vectorStore.redis.external.existingSecret` | Secret name for Redis password | `""` | -| `vectorStore.redis.external.passwordKey` | Key within the secret | `"password"` | -| `vectorStore.weaviate.external.existingSecret` | Secret name for Weaviate API key | `""` | -| `vectorStore.weaviate.external.apiKeyKey` | Key within the secret | `"api-key"` | -| `vectorStore.qdrant.external.existingSecret` | Secret name for Qdrant API key | `""` | -| `vectorStore.qdrant.external.apiKeyKey` | Key within the secret | `"api-key"` | -| `bifrost.plugins.maxim.secretRef.name` | Secret name for Maxim API key | `""` | -| `bifrost.plugins.maxim.secretRef.key` | Key within the secret | `"api-key"` | -| `bifrost.providerSecrets..existingSecret` | Secret name for provider API key | `""` | -| `bifrost.providerSecrets..key` | Key within the secret | `"api-key"` | -| `bifrost.providerSecrets..envVar` | Environment variable name to inject | `""` | - -### Provider Configuration - -Add provider keys via values file: - -```yaml -bifrost: - providers: - openai: - keys: - - value: "sk-..." - weight: 1 - anthropic: - keys: - - value: "sk-ant-..." 
- weight: 1 -``` - -Or via command line: - -```bash -helm install bifrost bifrost/bifrost \ - --set image.tag=1.3.45 \ - --set bifrost.providers.openai.keys[0].value="sk-..." \ - --set bifrost.providers.openai.keys[0].weight=1 -``` - -#### Using Environment Variables for Provider Keys - -Bifrost supports `env.VAR_NAME` syntax to reference environment variables. Combined with `providerSecrets`, you can keep API keys in Kubernetes secrets: - -```yaml -bifrost: - providers: - openai: - keys: - - value: "env.OPENAI_API_KEY" # References environment variable - weight: 1 - - # Inject secrets as environment variables - providerSecrets: - openai: - existingSecret: "my-openai-secret" - key: "api-key" - envVar: "OPENAI_API_KEY" -``` - -This pattern: -1. Creates a Kubernetes secret with the API key -2. Injects the secret as an environment variable (`OPENAI_API_KEY`) -3. Bifrost resolves `env.OPENAI_API_KEY` at runtime - -### Plugin Configuration - -Enable and configure plugins: - -```yaml -bifrost: - plugins: - telemetry: - enabled: true - config: {} - - logging: - enabled: true - config: {} - - governance: - enabled: true - config: - is_vk_mandatory: false - - semanticCache: - enabled: true - config: - provider: "openai" - embedding_model: "text-embedding-3-small" - dimension: 1536 - threshold: 0.8 - ttl: "5m" - cache_by_model: true - cache_by_provider: true -``` - -## Operations - -### Upgrade - -```bash -# Update repository -helm repo update - -# Upgrade with same values -helm upgrade bifrost bifrost/bifrost --reuse-values - -# Upgrade with new values -helm upgrade bifrost bifrost/bifrost -f your-values.yaml -``` - -### Rollback - -```bash -# View release history -helm history bifrost - -# Rollback to previous version -helm rollback bifrost - -# Rollback to specific revision -helm rollback bifrost 2 -``` - -### Uninstall - -```bash -# Uninstall release -helm uninstall bifrost - -# Delete PVCs (if you want to remove data) -kubectl delete pvc -l 
app.kubernetes.io/instance=bifrost -``` - -### Scale - -```bash -# Scale manually -kubectl scale deployment bifrost --replicas=5 - -# Or update via Helm -helm upgrade bifrost bifrost/bifrost \ - --set replicaCount=5 \ - --reuse-values -``` - -## Monitoring - -### Prometheus Metrics - -Bifrost exposes Prometheus metrics at `/metrics`. - -Enable ServiceMonitor for automatic scraping: - -```yaml -serviceMonitor: - enabled: true - interval: 30s - scrapeTimeout: 10s -``` - -### Health Checks - -Check pod health: - -```bash -# View pod status -kubectl get pods -l app.kubernetes.io/name=bifrost - -# Check logs -kubectl logs -l app.kubernetes.io/name=bifrost --tail=100 - -# Describe pod -kubectl describe pod -l app.kubernetes.io/name=bifrost -``` - -### Metrics Endpoints - -```bash -# Port forward -kubectl port-forward svc/bifrost 8080:8080 - -# Check metrics -curl http://localhost:8080/metrics - -# Check health -curl http://localhost:8080/health -``` - -## Troubleshooting + + -### Pod Not Starting +Fastest way to get running. Bifrost deploys as a StatefulSet with a 10Gi PVC for SQLite. 
```bash -# Check events -kubectl describe pod -l app.kubernetes.io/name=bifrost - -# Check logs -kubectl logs -l app.kubernetes.io/name=bifrost +kubectl create secret generic bifrost-encryption-key \ + --from-literal=encryption-key="$(openssl rand -base64 32)" -# Common issues: -# - Image pull errors: Check repository access -# - PVC binding: Check PVC status -# - Config errors: Validate ConfigMap +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set bifrost.encryptionKeySecret.name="bifrost-encryption-key" \ + --set bifrost.encryptionKeySecret.key="encryption-key" ``` -### Database Connection Issues + + + +Add your first provider key at install time: ```bash -# For embedded PostgreSQL -kubectl exec -it deployment/bifrost-postgresql -- psql -U bifrost +kubectl create secret generic bifrost-encryption-key \ + --from-literal=encryption-key="$(openssl rand -base64 32)" -# Check connectivity from pod -kubectl exec -it deployment/bifrost -- nc -zv bifrost-postgresql 5432 +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-your-key' -# Check secret -kubectl get secret bifrost-config -o yaml +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set bifrost.encryptionKeySecret.name="bifrost-encryption-key" \ + --set bifrost.encryptionKeySecret.key="encryption-key" \ + --set 'bifrost.providers.openai.keys[0].name=primary' \ + --set 'bifrost.providers.openai.keys[0].value=env.OPENAI_API_KEY' \ + --set 'bifrost.providers.openai.keys[0].weight=1' \ + --set bifrost.providerSecrets.openai.existingSecret="provider-keys" \ + --set bifrost.providerSecrets.openai.key="openai-api-key" \ + --set bifrost.providerSecrets.openai.envVar="OPENAI_API_KEY" ``` -### High Memory Usage - -```bash -# Check resource usage -kubectl top pods -l app.kubernetes.io/name=bifrost - -# Increase limits -helm upgrade bifrost bifrost/bifrost \ - --set resources.limits.memory=4Gi \ - --reuse-values -``` + + -### Ingress Not Working 
+High-availability setup — 3 replicas, PostgreSQL, autoscaling, ingress. ```bash -# Check ingress status -kubectl describe ingress bifrost +# 1. Create secrets +kubectl create secret generic bifrost-encryption-key \ + --from-literal=encryption-key="$(openssl rand -base64 32)" -# Check ingress controller logs -kubectl logs -n ingress-nginx -l app.kubernetes.io/name=ingress-nginx +kubectl create secret generic postgres-credentials \ + --from-literal=password="$(openssl rand -base64 32)" -# Verify DNS -nslookup bifrost.yourdomain.com +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-...' ``` -## Advanced Configuration - -### Custom Values File - -Create `my-values.yaml`: - ```yaml +# production.yaml image: - tag: "1.3.45" # Required: specify the Bifrost version + tag: "v1.4.11" replicaCount: 3 @@ -744,105 +107,157 @@ storage: postgresql: enabled: true auth: - password: "secure-password" + username: bifrost + database: bifrost + existingSecret: "postgres-credentials" + secretKeys: + adminPasswordKey: "password" + primary: + persistence: + size: 50Gi + resources: + requests: + cpu: 500m + memory: 1Gi + limits: + cpu: 2000m + memory: 2Gi autoscaling: enabled: true minReplicas: 3 maxReplicas: 10 + targetCPUUtilizationPercentage: 70 + targetMemoryUtilizationPercentage: 80 ingress: enabled: true className: nginx + annotations: + cert-manager.io/cluster-issuer: letsencrypt-prod hosts: - - host: bifrost.example.com + - host: bifrost.yourdomain.com paths: - path: / pathType: Prefix + tls: + - secretName: bifrost-tls + hosts: + - bifrost.yourdomain.com + +resources: + requests: + cpu: 500m + memory: 1Gi + limits: + cpu: 2000m + memory: 2Gi bifrost: - encryptionKey: "your-32-byte-key" + encryptionKeySecret: + name: "bifrost-encryption-key" + key: "encryption-key" + + client: + initialPoolSize: 500 + dropExcessRequests: true + enableLogging: true + providers: openai: keys: - - value: "sk-..." 
+      - name: "openai-primary"
+        value: "env.OPENAI_API_KEY"
         weight: 1
-```
 
-Install:
+  providerSecrets:
+    openai:
+      existingSecret: "provider-keys"
+      key: "openai-api-key"
+      envVar: "OPENAI_API_KEY"
+
+  plugins:
+    telemetry:
+      enabled: true
+    logging:
+      enabled: true
+    governance:
+      enabled: true
+```
 
 ```bash
-helm install bifrost bifrost/bifrost -f my-values.yaml
+# 2. Install
+helm install bifrost bifrost/bifrost -f production.yaml
 ```
 
-### Environment Variables
+
+
 
-Add custom environment variables:
+
+`image.tag` is required — the chart will not start without it. Check [Docker Hub](https://hub.docker.com/r/maximhq/bifrost/tags) for available versions.
+
 
-```yaml
-env:
-  - name: CUSTOM_VAR
-    value: "custom-value"
-
-envFrom:
-  - secretRef:
-      name: bifrost-secrets
-  - configMapRef:
-      name: bifrost-config
-```
+## Step 3 — Verify
 
-### Node Affinity
+```bash
+# Check pods are running
+kubectl get pods -l app.kubernetes.io/name=bifrost
 
-Deploy to specific nodes:
+# Port forward and hit the health endpoint
+kubectl port-forward svc/bifrost 8080:8080
+curl http://localhost:8080/health
 
-```yaml
-nodeSelector:
-  node-type: ai-workload
+# Check Prometheus metrics
+curl http://localhost:8080/metrics
+```
 
-affinity:
-  podAntiAffinity:
-    requiredDuringSchedulingIgnoredDuringExecution:
-    - labelSelector:
-        matchLabels:
-          app.kubernetes.io/name: bifrost
-      topologyKey: kubernetes.io/hostname
+## Step 4 — Make Your First Request
 
-tolerations:
-  - key: "gpu"
-    operator: "Equal"
-    value: "true"
-    effect: "NoSchedule"
+```bash
+# Make your first inference call
+curl http://localhost:8080/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-4o-mini",
+    "messages": [{"role": "user", "content": "Hello from Bifrost!"}]
+  }'
 ```
 
-## Enterprise Deployment
+When you're done, continue to [Next Steps](#next-steps).
+ + + + -For enterprise customers, Bifrost provides dedicated container images hosted in private registries with additional features, support, and SLAs. +Enterprise customers receive dedicated container images in a private registry, along with additional features, SLAs, and compliance documentation. [Book a demo](https://calendly.com/maximai/bifrost-demo) to know more about our enterprise features. -### Private Container Registry - -Enterprise customers receive access to Bifrost images in a private container registry. To use your enterprise registry, override the `image.repository` with your provided registry URL: +## Prerequisites - - +- Kubernetes cluster (v1.19+) +- `kubectl` configured +- Helm 3.2.0+ installed +- Enterprise registry credentials (provided by Maxim) -```yaml -# enterprise-gcp.yaml -image: - repository: us-west1-docker.pkg.dev/bifrost-enterprise/your-org/bifrost - tag: "latest" +## Step 1 — Add the Helm Repository -imagePullSecrets: - - name: gcr-secret +```bash +helm repo add bifrost https://maximhq.github.io/bifrost/helm-charts +helm repo update ``` -**Create the pull secret:** +## Step 2 — Create Pull Secret + +Create a Kubernetes image pull secret for our private enterprise registry: + + + ```bash -kubectl create secret docker-registry gcr-secret \ +kubectl create secret docker-registry enterprise-registry-secret \ --docker-server=us-west1-docker.pkg.dev \ --docker-username=_json_key \ --docker-password="$(cat service-account-key.json)" \ @@ -852,46 +267,22 @@ kubectl create secret docker-registry gcr-secret \ -```yaml -# enterprise-aws.yaml -image: - repository: 123456789.dkr.ecr.us-east-1.amazonaws.com/bifrost - tag: "latest" - -imagePullSecrets: - - name: ecr-secret -``` - -**Create the pull secret:** - ```bash -kubectl create secret docker-registry ecr-secret \ +kubectl create secret docker-registry enterprise-registry-secret \ --docker-server=123456789.dkr.ecr.us-east-1.amazonaws.com \ --docker-username=AWS \ --docker-password=$(aws ecr 
get-login-password --region us-east-1) ``` -ECR tokens expire after 12 hours. Consider using [ECR Credential Helper](https://github.com/awslabs/amazon-ecr-credential-helper) or an operator like [ECR Registry Creds](https://github.com/upmc-enterprises/registry-creds) for automatic token refresh. +ECR tokens expire after 12 hours. Use the [ECR Credential Helper](https://github.com/awslabs/amazon-ecr-credential-helper) or [ECR Registry Creds operator](https://github.com/upmc-enterprises/registry-creds) for automatic refresh. -```yaml -# enterprise-azure.yaml -image: - repository: yourregistry.azurecr.io/bifrost - tag: "latest" - -imagePullSecrets: - - name: acr-secret -``` - -**Create the pull secret:** - ```bash -kubectl create secret docker-registry acr-secret \ +kubectl create secret docker-registry enterprise-registry-secret \ --docker-server=yourregistry.azurecr.io \ --docker-username= \ --docker-password= @@ -900,20 +291,8 @@ kubectl create secret docker-registry acr-secret \ -```yaml -# enterprise-self-hosted.yaml -image: - repository: registry.yourcompany.com/ai/bifrost - tag: "latest" - -imagePullSecrets: - - name: private-registry-secret -``` - -**Create the pull secret:** - ```bash -kubectl create secret docker-registry private-registry-secret \ +kubectl create secret docker-registry enterprise-registry-secret \ --docker-server=registry.yourcompany.com \ --docker-username= \ --docker-password= @@ -922,14 +301,30 @@ kubectl create secret docker-registry private-registry-secret \ -### Full Enterprise Configuration +## Step 3 — Create Required Secrets + +```bash +# Encryption key +kubectl create secret generic bifrost-encryption \ + --from-literal=key="$(openssl rand -base64 32)" + +# Provider API keys +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-...' \ + --from-literal=anthropic-api-key='sk-ant-...' 
+ +# Admin credentials (for dashboard + governance) +kubectl create secret generic bifrost-admin-credentials \ + --from-literal=username='admin' \ + --from-literal=password='secure-admin-password' +``` -Complete example for enterprise deployments with all recommended settings: +## Step 4 — Install ```yaml -# enterprise-full.yaml +# enterprise.yaml image: - # Your enterprise registry URL (provided by Maxim) + # Registry URL provided by Maxim repository: us-west1-docker.pkg.dev/bifrost-enterprise/your-org/bifrost tag: "latest" @@ -938,7 +333,6 @@ imagePullSecrets: replicaCount: 3 -# Production-grade resources resources: requests: cpu: 1000m @@ -947,7 +341,6 @@ resources: cpu: 4000m memory: 8Gi -# Auto-scaling for high availability autoscaling: enabled: true minReplicas: 3 @@ -955,14 +348,13 @@ autoscaling: targetCPUUtilizationPercentage: 70 targetMemoryUtilizationPercentage: 80 -# PostgreSQL storage storage: mode: postgres postgresql: enabled: true auth: - password: "secure-password" # Use existingSecret in production + password: "secure-password" # use existingSecret in production primary: persistence: size: 100Gi @@ -974,7 +366,6 @@ postgresql: cpu: 4000m memory: 8Gi -# Vector store for semantic caching vectorStore: enabled: true type: weaviate @@ -983,7 +374,6 @@ vectorStore: persistence: size: 100Gi -# Ingress with TLS ingress: enabled: true className: nginx @@ -1000,17 +390,16 @@ ingress: hosts: - bifrost.yourcompany.com -# Bifrost configuration bifrost: encryptionKeySecret: name: "bifrost-encryption" key: "key" - + client: initialPoolSize: 1000 dropExcessRequests: true enableLogging: true - disableContentLogging: false # Set to true for compliance + disableContentLogging: false # set true for HIPAA/compliance logRetentionDays: 365 enforceGovernanceHeader: true allowDirectKeys: false @@ -1018,29 +407,29 @@ bifrost: allowedOrigins: - "https://yourcompany.com" - "https://*.yourcompany.com" - - # Use secrets for provider keys + providers: openai: keys: - - value: 
"env.OPENAI_API_KEY" + - name: "openai-primary" + value: "env.OPENAI_API_KEY" weight: 1 anthropic: keys: - - value: "env.ANTHROPIC_API_KEY" + - name: "anthropic-primary" + value: "env.ANTHROPIC_API_KEY" weight: 1 - + providerSecrets: openai: - existingSecret: "provider-api-keys" + existingSecret: "provider-keys" key: "openai-api-key" envVar: "OPENAI_API_KEY" anthropic: - existingSecret: "provider-api-keys" + existingSecret: "provider-keys" key: "anthropic-api-key" envVar: "ANTHROPIC_API_KEY" - - # Governance with authentication + governance: authConfig: isEnabled: true @@ -1048,8 +437,7 @@ bifrost: existingSecret: "bifrost-admin-credentials" usernameKey: "username" passwordKey: "password" - - # Enable all plugins + plugins: telemetry: enabled: true @@ -1068,7 +456,6 @@ bifrost: threshold: 0.85 ttl: "1h" -# Pod distribution affinity: podAntiAffinity: requiredDuringSchedulingIgnoredDuringExecution: @@ -1078,52 +465,159 @@ affinity: topologyKey: kubernetes.io/hostname ``` -### Enterprise Prerequisites +```bash +helm install bifrost bifrost/bifrost -f enterprise.yaml +``` + +Next steps: jump to [Next Steps](#next-steps). + +## Enterprise Support + +Enterprise customers have access to: +- Dedicated Slack channel for support +- Priority bug fixes and feature requests +- Custom feature development +- SLA guarantees +- Compliance documentation (SOC2, HIPAA, etc.) + +Contact [support@getmaxim.ai](mailto:support@getmaxim.ai) for support. + + + + + +--- + +## Operations -Before deploying, create the required secrets: +### Upgrade ```bash -# 1. Registry pull secret (see registry-specific instructions above) +helm repo update -# 2. Encryption key -kubectl create secret generic bifrost-encryption \ - --from-literal=key='your-32-byte-encryption-key' +# Upgrade reusing all existing values +helm upgrade bifrost bifrost/bifrost --reuse-values -# 3. Provider API keys -kubectl create secret generic provider-api-keys \ - --from-literal=openai-api-key='sk-...' 
\
-  --from-literal=anthropic-api-key='sk-ant-...'
+# Upgrade with new values
+helm upgrade bifrost bifrost/bifrost -f your-values.yaml
 
-# 4. Admin credentials (for governance)
-kubectl create secret generic bifrost-admin-credentials \
-  --from-literal=username='admin' \
-  --from-literal=password='secure-admin-password'
+# Upgrade and override a single field
+helm upgrade bifrost bifrost/bifrost \
+  --reuse-values \
+  --set image.tag=v1.4.11
 ```
 
-### Install Enterprise Build
+### Rollback
 
 ```bash
-helm install bifrost bifrost/bifrost -f enterprise-full.yaml
+helm history bifrost
+helm rollback bifrost     # to previous revision
+helm rollback bifrost 2   # to specific revision
 ```
 
-### Enterprise Support
+### Scale
 
-Enterprise customers have access to:
-- Dedicated Slack channel for support
-- Priority bug fixes and feature requests
-- Custom feature development
-- SLA guarantees
-- Compliance documentation (SOC2, HIPAA, etc.)
+```bash
+kubectl scale deployment bifrost --replicas=5
+
+# If you installed with SQLite storage, Bifrost runs as a StatefulSet instead:
+# kubectl scale statefulset bifrost --replicas=5
+
+# Or via Helm
+helm upgrade bifrost bifrost/bifrost \
+  --reuse-values \
+  --set replicaCount=5
+```
+
+### Uninstall
+
+```bash
+helm uninstall bifrost
+
+# Also remove PVCs (permanently deletes all data)
+kubectl delete pvc -l app.kubernetes.io/instance=bifrost
+```
+
+---
+
+## Monitoring
+
+### Prometheus Metrics
+
+Bifrost exposes Prometheus metrics at `/metrics`.
+
+Enable ServiceMonitor for automatic scraping:
+
+```yaml
+serviceMonitor:
+  enabled: true
+  interval: 30s
+  scrapeTimeout: 10s
+```
 
-Contact [support@getmaxim.ai](mailto:support@getmaxim.ai) for enterprise support.
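+
+To confirm scraping is wired up, you can check for the ServiceMonitor from the cluster side. This assumes the Prometheus Operator CRDs are installed in your cluster and that the chart applies its standard `app.kubernetes.io/instance` label:
+
+```bash
+# The ServiceMonitor CRD ships with the Prometheus Operator
+kubectl get crd servicemonitors.monitoring.coreos.com
+
+# List ServiceMonitors created for this release
+kubectl get servicemonitor -l app.kubernetes.io/instance=bifrost
+```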
+### Health Checks + +Check pod health: + +```bash +# View pod status +kubectl get pods -l app.kubernetes.io/name=bifrost + +# Check logs +kubectl logs -l app.kubernetes.io/name=bifrost --tail=100 + +# Describe pod +kubectl describe pod -l app.kubernetes.io/name=bifrost +``` + +### Metrics Endpoints + +```bash +# Port forward +kubectl port-forward svc/bifrost 8080:8080 + +# Check metrics +curl http://localhost:8080/metrics + +# Check health +curl http://localhost:8080/health +``` + +--- + +## Configuration Guides + + + + All parameters, secret references, advanced config, example patterns + + + Pool size, logging, CORS, header filtering, compat shims, MCP settings + + + OpenAI, Anthropic, Azure, Bedrock, Vertex, Groq, self-hosted + + + SQLite, PostgreSQL, object storage for logs, vector stores + + + Telemetry, logging, semantic cache, OTel, Datadog, governance + + + Budgets, rate limits, virtual keys, routing rules + + + Multi-replica HA, gossip, peer discovery + + + Pod startup, database, ingress, PVC, secrets, performance + + + +--- ## Resources - [Helm Chart Repository](https://github.com/maximhq/bifrost/tree/main/helm-charts) - [Artifact Hub](https://artifacthub.io/packages/helm/bifrost/bifrost) -- [Complete Installation Guide](https://github.com/maximhq/bifrost/blob/main/helm-charts/INSTALL.md) - [Example Configurations](https://github.com/maximhq/bifrost/tree/main/helm-charts/bifrost/values-examples) -- [Kubernetes Secrets Example](https://github.com/maximhq/bifrost/blob/main/helm-charts/bifrost/values-examples/secrets-from-k8s.yaml) - [GitHub Issues](https://github.com/maximhq/bifrost/issues) ## Next Steps diff --git a/docs/deployment-guides/helm/client.mdx b/docs/deployment-guides/helm/client.mdx new file mode 100644 index 0000000000..b3fd2dc968 --- /dev/null +++ b/docs/deployment-guides/helm/client.mdx @@ -0,0 +1,316 @@ +--- +title: "Client Configuration" +description: "Configure the Bifrost client: connection pool, logging, CORS, header filtering, compat 
shims, and MCP settings" +icon: "gear" +--- + +The `bifrost.client` block controls how Bifrost manages its internal worker pool, request logging, authentication enforcement, header policies, SDK compatibility shims, and MCP agent behaviour. All settings map directly to the `client` section of the rendered `config.json`. + +--- + +## Connection Pool + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.initialPoolSize` | Pre-allocated worker goroutines per provider queue | `300` | +| `bifrost.client.dropExcessRequests` | Drop requests when queue is full instead of waiting | `false` | + +A larger pool reduces latency spikes under burst load at the cost of higher baseline memory. For production workloads with multiple providers, `1000` is a common starting point. + +```yaml +# client-pool.yaml +image: + tag: "v1.4.11" + +bifrost: + client: + initialPoolSize: 1000 + dropExcessRequests: true # Return 429 instead of queuing indefinitely +``` + +```bash +helm install bifrost bifrost/bifrost -f client-pool.yaml + +# Or set inline +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set bifrost.client.initialPoolSize=1000 \ + --set bifrost.client.dropExcessRequests=true +``` + +--- + +## Request & Response Logging + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.enableLogging` | Log all LLM requests and responses | `true` | +| `bifrost.client.disableContentLogging` | Strip message content from logs (keeps metadata) | `false` | +| `bifrost.client.logRetentionDays` | Days to retain log entries in the store | `365` | +| `bifrost.client.loggingHeaders` | HTTP request headers to capture in log metadata | `[]` | + +Set `disableContentLogging: true` for HIPAA / PCI compliance workloads where message content must not be persisted. 
+ +```yaml +bifrost: + client: + enableLogging: true + disableContentLogging: true # PII / compliance: store metadata only + logRetentionDays: 90 + loggingHeaders: + - "x-request-id" + - "x-user-id" +``` + +```bash +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set bifrost.client.disableContentLogging=true \ + --set bifrost.client.logRetentionDays=90 +``` + +--- + +## Security & CORS + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.allowedOrigins` | CORS allowed origins | `["*"]` | +| `bifrost.client.allowDirectKeys` | Allow callers to pass provider keys directly in requests | `false` | +| `bifrost.client.enforceGovernanceHeader` | Require `x-bf-vk` virtual-key header on every request | `false` | +| `bifrost.client.maxRequestBodySizeMb` | Maximum allowed request body size | `100` | +| `bifrost.client.whitelistedRoutes` | Routes that bypass auth middleware | `[]` | + +```yaml +bifrost: + client: + allowedOrigins: + - "https://app.yourdomain.com" + - "https://admin.yourdomain.com" + allowDirectKeys: false # Prevent callers from supplying raw provider keys + enforceGovernanceHeader: true # Every request must carry a virtual key + maxRequestBodySizeMb: 50 + whitelistedRoutes: + - "/health" + - "/metrics" +``` + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set bifrost.client.enforceGovernanceHeader=true \ + --set bifrost.client.allowDirectKeys=false +``` + +--- + +## Header Filtering + +Controls which `x-bf-eh-*` headers are forwarded to upstream LLM providers. 
+ +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.headerFilterConfig.allowlist` | Only these headers are forwarded (whitelist mode) | `[]` | +| `bifrost.client.headerFilterConfig.denylist` | These headers are always blocked | `[]` | +| `bifrost.client.requiredHeaders` | Headers that must be present on every request | `[]` | +| `bifrost.client.allowedHeaders` | Additional headers permitted for CORS and WebSocket | `[]` | + +When both lists are empty, all `x-bf-eh-*` headers pass through. Specifying an `allowlist` enables strict whitelist mode — only listed headers are forwarded. + +```yaml +bifrost: + client: + headerFilterConfig: + allowlist: + - "x-bf-eh-anthropic-version" + - "x-bf-eh-openai-beta" + denylist: [] + requiredHeaders: + - "x-request-id" +``` + +--- + +## Authentication + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.authConfig.isEnabled` | Enable username/password auth for the API and dashboard | `false` | +| `bifrost.authConfig.adminUsername` | Admin username (plain text, prefer secret) | `""` | +| `bifrost.authConfig.adminPassword` | Admin password (plain text, prefer secret) | `""` | +| `bifrost.authConfig.existingSecret` | Kubernetes Secret name for credentials | `""` | +| `bifrost.authConfig.usernameKey` | Key within the secret for username | `"username"` | +| `bifrost.authConfig.passwordKey` | Key within the secret for password | `"password"` | +| `bifrost.authConfig.disableAuthOnInference` | Skip auth check on `/v1/*` inference routes | `false` | + +```bash +# Create secret first +kubectl create secret generic bifrost-admin \ + --from-literal=username='admin' \ + --from-literal=password='your-secure-password' +``` + +```yaml +bifrost: + authConfig: + isEnabled: true + disableAuthOnInference: false + existingSecret: "bifrost-admin" + usernameKey: "username" + passwordKey: "password" +``` + +```bash +helm upgrade bifrost bifrost/bifrost \ + 
--reuse-values \ + -f auth-values.yaml +``` + +--- + +## Encryption + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.encryptionKey` | 32-byte encryption key (plain text — use secret in production) | `""` | +| `bifrost.encryptionKeySecret.name` | Kubernetes Secret name containing the key | `""` | +| `bifrost.encryptionKeySecret.key` | Key within the secret | `"encryption-key"` | + +Always use a Kubernetes Secret in production: + +```bash +kubectl create secret generic bifrost-encryption \ + --from-literal=encryption-key='your-32-byte-encryption-key-here' +``` + +```yaml +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" +``` + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + -f encryption-values.yaml +``` + +--- + +## Async Jobs & Database Pings + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.disableDbPingsInHealth` | Exclude DB connectivity from `/health` checks | `false` | +| `bifrost.client.asyncJobResultTTL` | TTL (seconds) for async job results | `3600` | + +--- + +## Compat Shims + +Compatibility flags that let Bifrost silently adapt request/response shapes for SDK integrations: + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.compat.convertTextToChat` | Wrap legacy text completions as chat messages | `false` | +| `bifrost.client.compat.convertChatToResponses` | Translate chat completions to Responses API format | `false` | +| `bifrost.client.compat.shouldDropParams` | Silently drop unsupported parameters instead of erroring | `false` | +| `bifrost.client.compat.shouldConvertParams` | Auto-convert parameter names across provider schemas | `false` | + +```yaml +bifrost: + client: + compat: + shouldDropParams: true # Useful when proxying mixed SDK traffic + convertTextToChat: true # For clients using the legacy /v1/completions endpoint +``` + +--- + +## 
Prometheus Labels + +Add custom labels to every Prometheus metric emitted by Bifrost: + +```yaml +bifrost: + client: + prometheusLabels: + - name: "environment" + value: "production" + - name: "region" + value: "us-east-1" +``` + +--- + +## MCP Agent Settings + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.mcpAgentDepth` | Maximum tool-call recursion depth for MCP agent mode | `10` | +| `bifrost.client.mcpToolExecutionTimeout` | Timeout per tool execution in seconds | `30` | +| `bifrost.client.mcpCodeModeBindingLevel` | Code mode binding level (`server` or `tool`) | `""` | +| `bifrost.client.mcpToolSyncInterval` | Global tool sync interval in minutes (`0` = disabled) | `0` | + +```yaml +bifrost: + client: + mcpAgentDepth: 15 + mcpToolExecutionTimeout: 60 +``` + +--- + +## Full Example + +```yaml +# client-full.yaml +image: + tag: "v1.4.11" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" + + authConfig: + isEnabled: true + disableAuthOnInference: false + existingSecret: "bifrost-admin" + usernameKey: "username" + passwordKey: "password" + + client: + initialPoolSize: 1000 + dropExcessRequests: true + allowedOrigins: + - "https://app.yourdomain.com" + enableLogging: true + disableContentLogging: false + logRetentionDays: 90 + enforceGovernanceHeader: true + allowDirectKeys: false + maxRequestBodySizeMb: 100 + headerFilterConfig: + allowlist: [] + denylist: [] + prometheusLabels: + - name: "environment" + value: "production" + mcpAgentDepth: 10 + mcpToolExecutionTimeout: 30 +``` + +```bash +# Create prerequisites +kubectl create secret generic bifrost-encryption \ + --from-literal=encryption-key='your-32-byte-encryption-key-here' + +kubectl create secret generic bifrost-admin \ + --from-literal=username='admin' \ + --from-literal=password='your-secure-password' + +# Install +helm install bifrost bifrost/bifrost -f client-full.yaml +``` diff --git 
a/docs/deployment-guides/helm/cluster.mdx b/docs/deployment-guides/helm/cluster.mdx new file mode 100644 index 0000000000..ea86536e5c --- /dev/null +++ b/docs/deployment-guides/helm/cluster.mdx @@ -0,0 +1,513 @@ +--- +title: "Cluster Mode & HA" +description: "Run Bifrost in a multi-replica cluster with gossip-based peer discovery, distributed state sync, and high-availability configuration" +icon: "network-wired" +--- + +Cluster mode enables multiple Bifrost replicas to share state — rate limits, budget counters, and governance data — across pods. When `bifrost.cluster.enabled` is `false` (the default), each replica operates independently and state is only shared via the database. + + +Cluster mode requires **PostgreSQL** as the storage backend. SQLite is single-node only. + + +## When to Use Cluster Mode + +| Scenario | Recommendation | +|----------|---------------| +| Single replica | Not needed | +| Multiple replicas, shared DB only | Optional — DB provides eventual consistency | +| Multiple replicas with strict per-minute rate limiting | **Enable cluster mode** — in-memory counters are synced via gossip | +| Geographic multi-region | Enable cluster mode with DNS or Consul discovery | + +--- + +## Basic Cluster Setup + +```yaml +# cluster-values.yaml +image: + tag: "v1.4.11" + +replicaCount: 3 + +storage: + mode: postgres + +postgresql: + external: + enabled: true + host: "your-postgres-host.example.com" + port: 5432 + user: bifrost + database: bifrost + sslMode: require + existingSecret: "postgres-credentials" + passwordKey: "password" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" + + cluster: + enabled: true + gossip: + port: 7946 + config: + timeoutSeconds: 10 + successThreshold: 3 + failureThreshold: 3 + +# Spread replicas across nodes for true HA +affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchLabels: + app.kubernetes.io/name: bifrost + topologyKey: 
kubernetes.io/hostname + +# Conservative scale-down: avoid killing pods mid-stream +autoscaling: + enabled: true + minReplicas: 3 + maxReplicas: 10 + targetCPUUtilizationPercentage: 70 + behavior: + scaleDown: + stabilizationWindowSeconds: 300 + policies: + - type: Pods + value: 1 + periodSeconds: 120 + +# Give in-flight SSE streams time to drain +terminationGracePeriodSeconds: 90 +lifecycle: + preStop: + exec: + command: ["sh", "-c", "sleep 20"] +``` + +```bash +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-postgres-password' + +kubectl create secret generic bifrost-encryption \ + --from-literal=encryption-key='your-32-byte-encryption-key' + +helm install bifrost bifrost/bifrost -f cluster-values.yaml +``` + +--- + +## Peer Discovery + +Bifrost uses a gossip protocol (memberlist) for peer-to-peer state sync. Configure how peers find each other: + + + + + +Bifrost queries the Kubernetes API to find other Bifrost pods by label selector. No static peer list needed — works with HPA. 
+
+```yaml
+bifrost:
+  cluster:
+    enabled: true
+    discovery:
+      enabled: true
+      type: kubernetes
+      k8sNamespace: "default" # namespace where Bifrost runs
+      k8sLabelSelector: "app.kubernetes.io/name=bifrost"
+    gossip:
+      port: 7946
+```
+
+The service account needs permission to list pods:
+
+```yaml
+serviceAccount:
+  create: true
+  annotations: {}
+```
+
+```bash
+# Create a Role and RoleBinding for pod discovery (apply once)
+kubectl apply -f - <<'EOF'
+apiVersion: rbac.authorization.k8s.io/v1
+kind: Role
+metadata:
+  name: bifrost-pod-discovery
+  namespace: default
+rules:
+  - apiGroups: [""]
+    resources: ["pods"]
+    verbs: ["list", "get", "watch"]
+---
+apiVersion: rbac.authorization.k8s.io/v1
+kind: RoleBinding
+metadata:
+  name: bifrost-pod-discovery
+  namespace: default
+subjects:
+  - kind: ServiceAccount
+    name: bifrost
+    namespace: default
+roleRef:
+  kind: Role
+  name: bifrost-pod-discovery
+  apiGroup: rbac.authorization.k8s.io
+EOF
+```
+
+```bash
+helm install bifrost bifrost/bifrost -f cluster-k8s-discovery-values.yaml
+```
+
+
+
+
+
+Uses a headless service DNS name to resolve peer IPs. Works well with StatefulSets (predictable pod DNS names).
+
+```yaml
+bifrost:
+  cluster:
+    enabled: true
+    discovery:
+      enabled: true
+      type: dns
+      dnsNames:
+        - "bifrost-headless.default.svc.cluster.local"
+    gossip:
+      port: 7946
+```
+
+The chart automatically creates a headless service (`bifrost-headless`) when cluster mode is enabled with a StatefulSet. For Deployments, create it manually:
+
+```bash
+kubectl apply -f - <<'EOF'
+apiVersion: v1
+kind: Service
+metadata:
+  name: bifrost-headless
+spec:
+  clusterIP: None
+  selector:
+    app.kubernetes.io/name: bifrost
+  ports:
+    - name: gossip
+      port: 7946
+      protocol: TCP
+EOF
+```
+
+```bash
+helm install bifrost bifrost/bifrost -f cluster-dns-discovery-values.yaml
+```
+
+
+
+
+
+Enumerate peer addresses explicitly. Use when discovery mechanisms are unavailable or you want deterministic membership.
+ +```yaml +bifrost: + cluster: + enabled: true + peers: + - "bifrost-0.bifrost-headless.default.svc.cluster.local:7946" + - "bifrost-1.bifrost-headless.default.svc.cluster.local:7946" + - "bifrost-2.bifrost-headless.default.svc.cluster.local:7946" + gossip: + port: 7946 +``` + + +Static peers require StatefulSet pod names to be stable. This approach doesn't adapt to HPA-driven scaling — use Kubernetes or DNS discovery for dynamic replica counts. + + + + + + +```yaml +bifrost: + cluster: + enabled: true + discovery: + enabled: true + type: consul + consulAddress: "consul.consul.svc.cluster.local:8500" + gossip: + port: 7946 +``` + +```bash +helm install bifrost bifrost/bifrost -f cluster-consul-discovery-values.yaml +``` + + + + + +```yaml +bifrost: + cluster: + enabled: true + discovery: + enabled: true + type: etcd + etcdEndpoints: + - "http://etcd-0.etcd.default.svc.cluster.local:2379" + - "http://etcd-1.etcd.default.svc.cluster.local:2379" + - "http://etcd-2.etcd.default.svc.cluster.local:2379" + gossip: + port: 7946 +``` + + + + + +Best for local development or bare-metal clusters where multicast is available. 
+ +```yaml +bifrost: + cluster: + enabled: true + discovery: + enabled: true + type: mdns + mdnsService: "_bifrost._tcp" + gossip: + port: 7946 +``` + + + + + +--- + +## Allowed Address Space + +Restrict gossip to a specific subnet (useful in multi-tenant clusters): + +```yaml +bifrost: + cluster: + discovery: + enabled: true + type: kubernetes + k8sNamespace: "default" + k8sLabelSelector: "app.kubernetes.io/name=bifrost" + allowedAddressSpace: + - "10.0.0.0/8" + - "172.16.0.0/12" +``` + +--- + +## Region-Aware Routing + +Tag replicas with a region identifier for latency-aware routing: + +```yaml +bifrost: + cluster: + enabled: true + region: "us-east-1" +``` + +--- + +## Full HA Production Example + +```yaml +# ha-production-values.yaml +image: + tag: "v1.4.11" + +replicaCount: 3 + +resources: + requests: + cpu: 1000m + memory: 1Gi + limits: + cpu: 4000m + memory: 4Gi + +autoscaling: + enabled: true + minReplicas: 3 + maxReplicas: 15 + targetCPUUtilizationPercentage: 70 + targetMemoryUtilizationPercentage: 75 + behavior: + scaleDown: + stabilizationWindowSeconds: 300 + policies: + - type: Pods + value: 1 + periodSeconds: 120 + scaleUp: + stabilizationWindowSeconds: 30 + +terminationGracePeriodSeconds: 90 +lifecycle: + preStop: + exec: + command: ["sh", "-c", "sleep 20"] + +ingress: + enabled: true + className: nginx + annotations: + cert-manager.io/cluster-issuer: letsencrypt-prod + nginx.ingress.kubernetes.io/proxy-body-size: "100m" + nginx.ingress.kubernetes.io/proxy-read-timeout: "300" + hosts: + - host: bifrost.yourdomain.com + paths: + - path: / + pathType: Prefix + tls: + - secretName: bifrost-tls + hosts: + - bifrost.yourdomain.com + +storage: + mode: postgres + +postgresql: + external: + enabled: true + host: "rds.us-east-1.amazonaws.com" + port: 5432 + user: bifrost + database: bifrost + sslMode: require + existingSecret: "postgres-credentials" + passwordKey: "password" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: 
"encryption-key" + + client: + initialPoolSize: 1000 + dropExcessRequests: true + enableLogging: true + enforceGovernanceHeader: true + + cluster: + enabled: true + region: "us-east-1" + discovery: + enabled: true + type: kubernetes + k8sNamespace: "default" + k8sLabelSelector: "app.kubernetes.io/name=bifrost" + gossip: + port: 7946 + config: + timeoutSeconds: 10 + successThreshold: 3 + failureThreshold: 3 + + plugins: + telemetry: + enabled: true + config: + push_gateway: + enabled: true + push_gateway_url: "http://prometheus-pushgateway.monitoring.svc.cluster.local:9091" + push_interval: 15 + logging: + enabled: true + governance: + enabled: true + config: + is_vk_mandatory: true + +affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchLabels: + app.kubernetes.io/name: bifrost + topologyKey: kubernetes.io/hostname + +serviceAccount: + create: true + annotations: {} +``` + +```bash +# Prerequisites +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-secure-postgres-password' + +kubectl create secret generic bifrost-encryption \ + --from-literal=encryption-key='your-32-byte-encryption-key' + +# RBAC for Kubernetes pod discovery +kubectl apply -f - <<'EOF' +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: bifrost-pod-discovery + namespace: default +rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["list", "get", "watch"] +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: bifrost-pod-discovery + namespace: default +subjects: + - kind: ServiceAccount + name: bifrost + namespace: default +roleRef: + kind: Role + name: bifrost-pod-discovery + apiGroup: rbac.authorization.k8s.io +EOF + +# Install +helm install bifrost bifrost/bifrost -f ha-production-values.yaml + +# Verify all peers have found each other (check logs) +kubectl logs -l app.kubernetes.io/name=bifrost --tail=50 | grep -i gossip +``` + +--- + +## Verifying 
Cluster Health + +```bash +# Check all pods are running +kubectl get pods -l app.kubernetes.io/name=bifrost + +# Check gossip port is reachable between pods +kubectl exec -it bifrost-0 -- nc -zv bifrost-1.bifrost-headless 7946 + +# Check health endpoint +kubectl port-forward svc/bifrost 8080:8080 & +curl http://localhost:8080/health + +# View HPA status +kubectl get hpa bifrost + +# Scale manually during maintenance +kubectl scale deployment bifrost --replicas=5 +``` diff --git a/docs/deployment-guides/helm/governance.mdx b/docs/deployment-guides/helm/governance.mdx new file mode 100644 index 0000000000..3679d214d4 --- /dev/null +++ b/docs/deployment-guides/helm/governance.mdx @@ -0,0 +1,422 @@ +--- +title: "Governance" +description: "Configure Bifrost governance in Helm — budgets, rate limits, virtual keys, routing rules, and admin authentication" +icon: "shield" +--- + +Governance lets you control who can call which providers, how much they can spend, how fast they can go, and how traffic is routed. Everything is declared under `bifrost.governance` in your values file and seeded into the database at startup. + + +The governance **plugin** must also be enabled for enforcement to take effect: + +```yaml +bifrost: + plugins: + governance: + enabled: true +``` + +See the [Plugins](/deployment-guides/helm/plugins) page for plugin configuration details. + + +--- + +## Admin Authentication + +Protect the Bifrost dashboard and management API with username/password auth. 
+
+```bash
+kubectl create secret generic bifrost-admin-credentials \
+  --from-literal=username='admin' \
+  --from-literal=password='your-secure-admin-password'
+```
+
+```yaml
+bifrost:
+  governance:
+    authConfig:
+      isEnabled: true
+      disableAuthOnInference: false # keep auth on inference routes
+      existingSecret: "bifrost-admin-credentials"
+      usernameKey: "username"
+      passwordKey: "password"
+```
+
+```bash
+helm upgrade bifrost bifrost/bifrost --reuse-values -f governance-auth-values.yaml
+```
+
+---
+
+## Budgets
+
+Spending caps that reset after a configurable period. Budgets are referenced by ID from virtual keys, teams, customers, or providers.
+
+| Reset duration | Syntax |
+|----------------|--------|
+| 30 seconds | `"30s"` |
+| 5 minutes | `"5m"` |
+| 1 hour | `"1h"` |
+| 1 day | `"1d"` |
+| 1 week | `"1w"` |
+| 1 month | `"1M"` |
+| 1 year | `"1Y"` |
+
+```yaml
+bifrost:
+  governance:
+    budgets:
+      - id: "budget-dev"
+        max_limit: 50 # $50 per month
+        reset_duration: "1M"
+
+      - id: "budget-production"
+        max_limit: 500 # $500 per month
+        reset_duration: "1M"
+
+      - id: "budget-testing"
+        max_limit: 10 # $10 per day
+        reset_duration: "1d"
+
+      - id: "budget-enterprise"
+        max_limit: 5000 # $5000 per month
+        reset_duration: "1M"
+```
+
+---
+
+## Rate Limits
+
+Token and request-count caps per time window. Referenced by ID from virtual keys, teams, customers, or providers.
+ +```yaml +bifrost: + governance: + rateLimits: + - id: "rate-limit-standard" + token_max_limit: 100000 # 100K tokens per hour + token_reset_duration: "1h" + request_max_limit: 1000 # 1000 requests per hour + request_reset_duration: "1h" + + - id: "rate-limit-high" + token_max_limit: 500000 # 500K tokens per hour + token_reset_duration: "1h" + request_max_limit: 5000 + request_reset_duration: "1h" + + - id: "rate-limit-burst" + token_max_limit: 50000 # 50K tokens per minute (burst) + token_reset_duration: "1m" + request_max_limit: 500 + request_reset_duration: "1m" + + - id: "rate-limit-testing" + token_max_limit: 10000 + token_reset_duration: "1h" + request_max_limit: 100 + request_reset_duration: "1h" +``` + +--- + +## Customers & Teams + +Optional organizational hierarchy. Virtual keys can be assigned to customers or teams, inheriting their budgets and rate limits. + +```yaml +bifrost: + governance: + customers: + - id: "customer-acme" + name: "Acme Corp" + budget_id: "budget-production" + rate_limit_id: "rate-limit-high" + + - id: "customer-startup" + name: "Startup Inc" + budget_id: "budget-dev" + rate_limit_id: "rate-limit-standard" + + teams: + - id: "team-platform" + name: "Platform Team" + customer_id: "customer-acme" + budget_id: "budget-enterprise" + rate_limit_id: "rate-limit-high" + + - id: "team-ml" + name: "ML Team" + customer_id: "customer-acme" + budget_id: "budget-production" + rate_limit_id: "rate-limit-standard" +``` + +--- + +## Virtual Keys + +Virtual keys are the primary access tokens issued to callers. They scope which providers, models, and underlying API keys are accessible. + +```yaml +bifrost: + governance: + virtualKeys: + # 1. Unrestricted dev key — access to every provider + - id: "vk-dev-all" + name: "Dev: all providers" + value: "vk-dev-all-secret-token" + is_active: true + budget_id: "budget-dev" + rate_limit_id: "rate-limit-standard" + # No provider_configs → all providers allowed + + # 2. 
OpenAI only — restricted to two models + - id: "vk-openai-prod" + name: "OpenAI Production" + value: "vk-openai-prod-secret-token" + is_active: true + budget_id: "budget-production" + rate_limit_id: "rate-limit-high" + provider_configs: + - provider: "openai" + weight: 1 + allowed_models: ["gpt-4o", "gpt-4o-mini"] + # No keys[] → all configured OpenAI keys allowed + + # 3. Multi-provider with weighted routing + - id: "vk-multi" + name: "Multi-provider weighted" + value: "vk-multi-secret-token" + is_active: true + budget_id: "budget-production" + rate_limit_id: "rate-limit-high" + provider_configs: + - provider: "openai" + weight: 2 # 50% + allowed_models: ["*"] + - provider: "anthropic" + weight: 1 # 25% + allowed_models: ["*"] + - provider: "groq" + weight: 1 # 25% + allowed_models: ["*"] + + # 4. Team-scoped key + - id: "vk-platform-team" + name: "Platform Team Key" + value: "vk-platform-team-token" + is_active: true + team_id: "team-platform" # inherits team budget/rate-limit + provider_configs: + - provider: "openai" + weight: 1 + allowed_models: ["*"] + keys: + - name: "openai-primary" # pin to specific configured key + + # 5. Restricted testing key + - id: "vk-testing" + name: "Testing (gpt-4o-mini only)" + value: "vk-testing-token" + is_active: true + budget_id: "budget-testing" + rate_limit_id: "rate-limit-testing" + provider_configs: + - provider: "openai" + weight: 1 + allowed_models: ["gpt-4o-mini"] + + # 6. 
Batch API key + - id: "vk-batch" + name: "Batch API workloads" + value: "vk-batch-token" + is_active: true + budget_id: "budget-production" + rate_limit_id: "rate-limit-burst" + provider_configs: + - provider: "openai" + weight: 1 + allowed_models: ["*"] + keys: + - name: "openai-batch" # only the batch-flagged key +``` + +**Use a virtual key in API calls:** + +```bash +curl http://localhost:8080/v1/chat/completions \ + -H "x-bf-vk: vk-openai-prod-secret-token" \ + -H "Content-Type: application/json" \ + -d '{"model":"gpt-4o","messages":[{"role":"user","content":"Hello"}]}' +``` + +--- + +## Model Configs + +Apply budgets and rate limits at the model level, independent of virtual keys: + +```yaml +bifrost: + governance: + modelConfigs: + - id: "model-gpt4o" + model_name: "gpt-4o" + provider: "openai" + budget_id: "budget-production" + rate_limit_id: "rate-limit-high" + + - id: "model-claude" + model_name: "claude-3-5-sonnet-20241022" + provider: "anthropic" + rate_limit_id: "rate-limit-standard" +``` + +--- + +## Provider Governance + +Apply budgets and rate limits at the provider level: + +```yaml +bifrost: + governance: + providers: + - name: "openai" + budget_id: "budget-production" + rate_limit_id: "rate-limit-high" + send_back_raw_request: false + send_back_raw_response: false + + - name: "anthropic" + budget_id: "budget-production" + rate_limit_id: "rate-limit-standard" +``` + +--- + +## Routing Rules + +CEL-expression-based routing rules redirect requests to different providers or models based on request attributes. 
+ +| Field | Description | +|-------|-------------| +| `cel_expression` | CEL expression evaluated against the request; if `true`, rule fires | +| `targets` | Provider/model targets with weights | +| `fallbacks` | Providers to try if all targets fail | +| `scope` | `global`, `team`, `customer`, or `virtual_key` | +| `scope_id` | Required for non-global scopes | +| `priority` | Lower number = evaluated first | + +```yaml +bifrost: + governance: + routingRules: + # Route all GPT requests to Azure + - id: "route-gpt-to-azure" + name: "GPT → Azure" + description: "Route all GPT model requests to Azure OpenAI" + enabled: true + cel_expression: "model.startsWith('gpt-')" + targets: + - provider: "azure" + model: "" # empty = use original model name + weight: 1.0 + fallbacks: ["openai"] + scope: "global" + priority: 0 + + # Route heavy models to a slower but cheaper provider + - id: "route-heavy-to-groq" + name: "Large context → Groq" + enabled: true + cel_expression: "model == 'gpt-4o' && request_body.max_tokens > 4000" + targets: + - provider: "groq" + model: "llama-3.3-70b-versatile" + weight: 1.0 + fallbacks: ["openai"] + scope: "global" + priority: 1 + + # Team-scoped rule + - id: "route-ml-team-bedrock" + name: "ML Team → Bedrock" + enabled: true + cel_expression: "true" # match all requests for this scope + targets: + - provider: "bedrock" + model: "" + weight: 1.0 + fallbacks: ["openai"] + scope: "team" + scope_id: "team-ml" + priority: 0 +``` + +--- + +## Full Example + +```yaml +# governance-full-values.yaml +image: + tag: "v1.4.11" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" + + plugins: + governance: + enabled: true + config: + is_vk_mandatory: true + + governance: + authConfig: + isEnabled: true + existingSecret: "bifrost-admin-credentials" + usernameKey: "username" + passwordKey: "password" + + budgets: + - id: "budget-production" + max_limit: 500 + reset_duration: "1M" + - id: "budget-dev" + max_limit: 50 + 
reset_duration: "1M" + + rateLimits: + - id: "rate-limit-standard" + token_max_limit: 100000 + token_reset_duration: "1h" + request_max_limit: 1000 + request_reset_duration: "1h" + + virtualKeys: + - id: "vk-production" + name: "Production" + value: "vk-prod-secret-token" + is_active: true + budget_id: "budget-production" + rate_limit_id: "rate-limit-standard" + provider_configs: + - provider: "openai" + weight: 1 + allowed_models: ["gpt-4o", "gpt-4o-mini"] +``` + +```bash +kubectl create secret generic bifrost-encryption \ + --from-literal=encryption-key='your-32-byte-key' + +kubectl create secret generic bifrost-admin-credentials \ + --from-literal=username='admin' \ + --from-literal=password='secure-admin-password' + +helm install bifrost bifrost/bifrost -f governance-full-values.yaml +``` diff --git a/docs/deployment-guides/helm/guardrails.mdx b/docs/deployment-guides/helm/guardrails.mdx new file mode 100644 index 0000000000..4604b426e4 --- /dev/null +++ b/docs/deployment-guides/helm/guardrails.mdx @@ -0,0 +1,262 @@ +--- +title: "Guardrails" +description: "Configure guardrails providers and rules in Bifrost Helm deployments" +icon: "shield-halved" +--- + + +Guardrails are an **enterprise-only** feature. They require the enterprise Bifrost image. + + +Guardrails are configured under `bifrost.guardrails` in your values file. The configuration has two parts: + +- **`providers`** — the backend that performs the check. Rules link to providers by `id`. +- **`rules`** — CEL expressions that control when and where providers are invoked. + +--- + +## Providers + + + + +Runs entirely in-process with no external dependency. Patterns use RE2 syntax. Supports optional per-pattern flags: `i` (case-insensitive), `m` (multiline), `s` (dot-all). 
+ +```yaml +bifrost: + guardrails: + providers: + - id: 1 + provider_name: "regex" + policy_name: "block-secrets" + enabled: true + timeout: 5 + config: + patterns: + - pattern: "sk-[A-Za-z0-9]{20,}" + description: "OpenAI API key" + - pattern: "AKIA[0-9A-Z]{16}" + description: "AWS access key" + flags: "i" + - pattern: "gh[ps]_[A-Za-z0-9]{36}" + description: "GitHub token" +``` + + + + +```yaml +bifrost: + guardrails: + providers: + - id: 2 + provider_name: "bedrock" + policy_name: "content-filter" + enabled: true + timeout: 15 + config: + guardrail_arn: "arn:aws:bedrock:us-east-1::guardrail/abc123" + guardrail_version: "DRAFT" # or a published version number + region: "us-east-1" + access_key: "env.AWS_ACCESS_KEY_ID" # omit to use instance role + secret_key: "env.AWS_SECRET_ACCESS_KEY" +``` + + + + +```yaml +bifrost: + guardrails: + providers: + - id: 3 + provider_name: "azure" + policy_name: "azure-content-safety" + enabled: true + timeout: 10 + config: + endpoint: "https://your-resource.cognitiveservices.azure.com" + api_key: "env.AZURE_CONTENT_SAFETY_KEY" + analyze_enabled: true + analyze_severity_threshold: "medium" # low | medium | high + jailbreak_shield_enabled: true + indirect_attack_shield_enabled: true + copyright_enabled: false + text_blocklist_enabled: false + blocklist_names: [] +``` + + + + +```yaml +bifrost: + guardrails: + providers: + - id: 4 + provider_name: "grayswan" + policy_name: "grayswan-jailbreak" + enabled: true + timeout: 15 + config: + api_key: "env.GRAYSWAN_API_KEY" + violation_threshold: 0.7 # 0.0–1.0; higher = more permissive + reasoning_mode: "standard" # standard | fast + policy_id: "" # optional: single policy ID + policy_ids: [] # optional: multiple policy IDs + rules: {} # optional: inline rule map +``` + + + + +--- + +## Rules + +Rules are CEL expressions that fire when their condition is met. 
Available CEL variables: + +| Variable | Type | Description | +|----------|------|-------------| +| `model` | `string` | Model name from the request | +| `provider` | `string` | Provider name (e.g. `"openai"`) | +| `headers` | `map` | HTTP request headers | +| `params` | `map` | Query parameters | +| `customer` | `string` | Customer ID | +| `team` | `string` | Team ID | +| `user` | `string` | User ID | + +Rule fields: + +| Field | Required | Description | +|-------|----------|-------------| +| `id` | Yes | Unique integer ID | +| `name` | Yes | Human-readable name | +| `description` | No | Optional description | +| `enabled` | Yes | `true` to activate | +| `cel_expression` | Yes | CEL boolean expression; `"true"` matches all requests | +| `apply_to` | Yes | `"input"`, `"output"`, or `"both"` | +| `sampling_rate` | No | `0`–`100`; percentage of requests to check (default: 100) | +| `timeout` | No | Rule timeout in seconds | +| `provider_config_ids` | No | Provider `id`s to invoke when this rule matches | + +```yaml +bifrost: + guardrails: + rules: + - id: 101 + name: "block-secrets-input" + description: "Block prompts containing API keys" + enabled: true + cel_expression: "true" + apply_to: "input" + sampling_rate: 100 + timeout: 10 + provider_config_ids: [1] + + - id: 102 + name: "azure-output-gpt4o" + description: "Scan GPT-4o responses" + enabled: true + cel_expression: "model == 'gpt-4o'" + apply_to: "output" + sampling_rate: 100 + timeout: 15 + provider_config_ids: [3] + + - id: 103 + name: "grayswan-openai-input" + enabled: true + cel_expression: "provider == 'openai'" + apply_to: "input" + sampling_rate: 50 + timeout: 20 + provider_config_ids: [4] + + - id: 104 + name: "strict-team-check" + enabled: true + cel_expression: "team == 'team-platform'" + apply_to: "both" + sampling_rate: 100 + timeout: 30 + provider_config_ids: [1, 3] # multiple providers run in parallel +``` + +--- + +## Full example + +```yaml +# guardrails-values.yaml +image: + tag: "latest" + 
+bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" + + guardrails: + providers: + - id: 1 + provider_name: "regex" + policy_name: "block-secrets" + enabled: true + timeout: 5 + config: + patterns: + - pattern: "sk-[A-Za-z0-9]{20,}" + description: "OpenAI API key" + - pattern: "AKIA[0-9A-Z]{16}" + description: "AWS access key" + - pattern: "gh[ps]_[A-Za-z0-9]{36}" + description: "GitHub token" + + - id: 2 + provider_name: "azure" + policy_name: "content-safety" + enabled: true + timeout: 10 + config: + endpoint: "https://your-resource.cognitiveservices.azure.com" + api_key: "env.AZURE_CONTENT_SAFETY_KEY" + analyze_enabled: true + analyze_severity_threshold: "medium" + jailbreak_shield_enabled: true + indirect_attack_shield_enabled: false + copyright_enabled: false + text_blocklist_enabled: false + + rules: + - id: 101 + name: "block-secrets-input" + description: "Block prompts leaking credentials" + enabled: true + cel_expression: "true" + apply_to: "input" + sampling_rate: 100 + timeout: 10 + provider_config_ids: [1] + + - id: 102 + name: "content-safety-both" + description: "Azure content safety on input and output" + enabled: true + cel_expression: "true" + apply_to: "both" + sampling_rate: 100 + timeout: 15 + provider_config_ids: [2] +``` + +```bash +kubectl create secret generic azure-content-safety \ + --from-literal=key='your-azure-content-safety-api-key' + +helm install bifrost bifrost/bifrost \ + -f guardrails-values.yaml \ + --set env[0].name=AZURE_CONTENT_SAFETY_KEY \ + --set env[0].valueFrom.secretKeyRef.name=azure-content-safety \ + --set env[0].valueFrom.secretKeyRef.key=key +``` \ No newline at end of file diff --git a/docs/deployment-guides/helm/plugins.mdx b/docs/deployment-guides/helm/plugins.mdx new file mode 100644 index 0000000000..79a4c4f788 --- /dev/null +++ b/docs/deployment-guides/helm/plugins.mdx @@ -0,0 +1,537 @@ +--- +title: "Plugins" +description: "Configure Bifrost plugins in Helm — telemetry, logging, 
semantic cache, OpenTelemetry, Datadog, governance, and custom plugins" +icon: "puzzle-piece" +--- + +Plugins are configured under `bifrost.plugins`. Each plugin is independently enabled/disabled. Pre-hooks run in registration order; post-hooks run in reverse order. + + +**Telemetry, logging, and governance are auto-loaded built-ins** — they are always active and do not need to be explicitly enabled. Their configuration lives in `bifrost.client.*` and `bifrost.governance.*`, not in the `plugins` block. + +The `plugins` block controls the opt-in plugins: `semanticCache`, `otel`, `datadog`, `maxim`, and custom plugins. + + +```yaml +bifrost: + plugins: + semanticCache: + enabled: false + otel: + enabled: false + datadog: + enabled: false +``` + +```bash +# Enable an opt-in plugin at install time +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set bifrost.plugins.otel.enabled=true + +# Or upgrade to enable a plugin without touching other values +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set bifrost.plugins.semanticCache.enabled=true +``` + +--- + + + + + +### Telemetry (Prometheus) + + +Telemetry is **always active** — it cannot be disabled. You do not need to set `bifrost.plugins.telemetry.enabled`. + + +Exposes Prometheus metrics at `GET /metrics`. 
Custom labels are set via `bifrost.client.prometheusLabels`: + +```yaml +bifrost: + client: + prometheusLabels: + - "environment=production" + - "region=us-east-1" +``` + +```bash +# Verify metrics are exposed +kubectl port-forward svc/bifrost 8080:8080 & +curl http://localhost:8080/metrics | head -30 +``` + +**With Prometheus Push Gateway** (recommended for multi-replica / HA setups where pull-based scraping can miss pods): + +```yaml +bifrost: + plugins: + telemetry: + enabled: true + config: + push_gateway: + enabled: true + push_gateway_url: "http://prometheus-pushgateway.monitoring.svc.cluster.local:9091" + job_name: "bifrost" + instance_id: "" # auto-derived from pod name if empty + push_interval: 15 + basic_auth: + username: "" + password: "" +``` + +**ServiceMonitor for Prometheus Operator:** + +```yaml +serviceMonitor: + enabled: true + interval: 30s + scrapeTimeout: 10s + namespace: monitoring # namespace where Prometheus is deployed +``` + + + + + +### Request/Response Logging + + +Logging is **auto-loaded** when `bifrost.client.enableLogging: true` and a log store is configured. You do not need to set `bifrost.plugins.logging.enabled`. + + +Configure logging via the `client` block: + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.enableLogging` | Enable request/response logging | `true` | +| `bifrost.client.disableContentLogging` | Strip message body from logs (HIPAA/PCI) | `false` | +| `bifrost.client.loggingHeaders` | HTTP headers to capture in log metadata | `[]` | + +```yaml +bifrost: + client: + enableLogging: true + disableContentLogging: false # set true for HIPAA/compliance + loggingHeaders: + - "x-request-id" + - "x-user-id" + - "x-team-id" +``` + +```bash +# Verify logs are being written +kubectl port-forward svc/bifrost 8080:8080 & +curl -s "http://localhost:8080/api/logs?limit=5" | jq . +``` + +See [Client Configuration](/deployment-guides/helm/client) for the full reference. 
+ + + + + +### Governance + + +Governance is **always active** for OSS deployments. You do not need to set `bifrost.plugins.governance.enabled`. + + +Virtual key enforcement is controlled by the `client` block: + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.client.enforceAuthOnInference` | Require a virtual key (`x-bf-vk`) on every inference request | `false` | + +```yaml +bifrost: + client: + enforceAuthOnInference: true # require virtual key on all inference requests +``` + +Define virtual keys, budgets, rate limits, and routing rules in `bifrost.governance.*`. See the [Governance](/deployment-guides/helm/governance) page. + + + + + +### Semantic Cache + +Caches LLM responses using vector similarity so semantically equivalent prompts return cached answers. + +Two modes: +- **Semantic mode** (`dimension > 1`): uses an embedding model + vector store for similarity search +- **Direct / hash mode** (`dimension: 1`): exact-match hash-based caching, no embedding model needed + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.plugins.semanticCache.enabled` | Enable semantic caching | `false` | +| `bifrost.plugins.semanticCache.config.provider` | Embedding provider | `"openai"` | +| `bifrost.plugins.semanticCache.config.embedding_model` | Embedding model name | `"text-embedding-3-small"` | +| `bifrost.plugins.semanticCache.config.dimension` | Embedding dimension (`1` = direct/hash mode) | `1536` | +| `bifrost.plugins.semanticCache.config.threshold` | Cosine similarity threshold (0–1) | `0.8` | +| `bifrost.plugins.semanticCache.config.ttl` | Cache entry TTL (Go duration) | `"5m"` | +| `bifrost.plugins.semanticCache.config.conversation_history_threshold` | Number of past messages to include in cache key | `3` | +| `bifrost.plugins.semanticCache.config.cache_by_model` | Include model name in cache key | `true` | +| `bifrost.plugins.semanticCache.config.cache_by_provider` | Include provider 
name in cache key | `true` | +| `bifrost.plugins.semanticCache.config.exclude_system_prompt` | Exclude system prompt from cache key | `false` | +| `bifrost.plugins.semanticCache.config.cleanup_on_shutdown` | Delete cache data on pod shutdown | `false` | + +**Semantic mode (with OpenAI embeddings + Weaviate):** + +```bash +kubectl create secret generic semantic-cache-secret \ + --from-literal=openai-key='sk-your-openai-embedding-key' +``` + +```yaml +# semantic-cache-values.yaml +image: + tag: "v1.4.11" + +vectorStore: + enabled: true + type: weaviate + weaviate: + enabled: true + persistence: + size: 20Gi + +bifrost: + plugins: + semanticCache: + enabled: true + config: + provider: "openai" + keys: + - value: "env.SEMANTIC_CACHE_OPENAI_KEY" + weight: 1 + embedding_model: "text-embedding-3-small" + dimension: 1536 + threshold: 0.85 + ttl: "1h" + conversation_history_threshold: 5 + cache_by_model: true + cache_by_provider: true + + providerSecrets: + semantic-cache-key: + existingSecret: "semantic-cache-secret" + key: "openai-key" + envVar: "SEMANTIC_CACHE_OPENAI_KEY" +``` + +```bash +helm install bifrost bifrost/bifrost -f semantic-cache-values.yaml +``` + +**Direct / hash mode** (no embedding provider needed): + +```yaml +bifrost: + plugins: + semanticCache: + enabled: true + config: + dimension: 1 # triggers hash-based exact matching + ttl: "30m" + cache_by_model: true + cache_by_provider: true +``` + + +The vector store (`vectorStore.*`) must be configured and enabled for semantic mode. Direct/hash mode works without a vector store but still requires a storage backend. + + + + + + +### OpenTelemetry (OTel) + +Sends distributed traces and push-based metrics to any OTLP-compatible collector (Jaeger, Tempo, Honeycomb, etc.). 
+ +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.plugins.otel.enabled` | Enable OTel tracing | `false` | +| `bifrost.plugins.otel.config.service_name` | Service name in traces | `"bifrost"` | +| `bifrost.plugins.otel.config.collector_url` | OTLP collector endpoint | `""` | +| `bifrost.plugins.otel.config.trace_type` | Trace type (`genai_extension` or `default`) | `"genai_extension"` | +| `bifrost.plugins.otel.config.protocol` | Transport protocol (`grpc` or `http`) | `"grpc"` | +| `bifrost.plugins.otel.config.metrics_enabled` | Enable OTLP push-based metrics | `false` | +| `bifrost.plugins.otel.config.metrics_endpoint` | OTLP metrics endpoint | `""` | +| `bifrost.plugins.otel.config.metrics_push_interval` | Push interval in seconds | `15` | +| `bifrost.plugins.otel.config.headers` | Custom headers for the collector | `{}` | +| `bifrost.plugins.otel.config.insecure` | Skip TLS verification | `false` | +| `bifrost.plugins.otel.config.tls_ca_cert` | Path to CA cert for TLS | `""` | + +```yaml +# otel-values.yaml +image: + tag: "v1.4.11" + +bifrost: + plugins: + otel: + enabled: true + config: + service_name: "bifrost-production" + collector_url: "otel-collector.observability.svc.cluster.local:4317" + trace_type: "genai_extension" + protocol: "grpc" + insecure: true # set false in production with a proper cert + metrics_enabled: true + metrics_endpoint: "otel-collector.observability.svc.cluster.local:4317" + metrics_push_interval: 15 + headers: + x-honeycomb-team: "env.HONEYCOMB_API_KEY" +``` + +```bash +helm upgrade bifrost bifrost/bifrost --reuse-values -f otel-values.yaml +``` + +**With authentication headers from a Kubernetes Secret:** + +```bash +kubectl create secret generic otel-credentials \ + --from-literal=api-key='your-honeycomb-or-grafana-key' +``` + +```yaml +bifrost: + plugins: + otel: + enabled: true + config: + collector_url: "api.honeycomb.io:443" + protocol: "grpc" + headers: + x-honeycomb-team: 
"env.OTEL_API_KEY" + + providerSecrets: + otel-key: + existingSecret: "otel-credentials" + key: "api-key" + envVar: "OTEL_API_KEY" +``` + + + + + +### Datadog APM + +Sends traces to a Datadog Agent running in the cluster. + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.plugins.datadog.enabled` | Enable Datadog tracing | `false` | +| `bifrost.plugins.datadog.config.service_name` | Service name | `"bifrost"` | +| `bifrost.plugins.datadog.config.agent_addr` | Datadog Agent address | `"localhost:8126"` | +| `bifrost.plugins.datadog.config.env` | Deployment environment tag | `""` | +| `bifrost.plugins.datadog.config.version` | Version tag | `""` | +| `bifrost.plugins.datadog.config.enable_traces` | Enable trace collection | `true` | +| `bifrost.plugins.datadog.config.custom_tags` | Extra tags on all spans | `{}` | + +The Datadog Agent is typically deployed via the [Datadog Helm chart](https://docs.datadoghq.com/containers/kubernetes/installation/) as a DaemonSet, making it available at the node's hostIP. + +```yaml +# datadog-values.yaml +image: + tag: "v1.4.11" + +bifrost: + plugins: + datadog: + enabled: true + config: + service_name: "bifrost" + agent_addr: "$(HOST_IP):8126" # uses Datadog DaemonSet pattern + env: "production" + version: "v1.4.11" + enable_traces: true + custom_tags: + team: "platform" + region: "us-east-1" + +# Inject HOST_IP so Bifrost can reach the DaemonSet agent on the same node +env: + - name: HOST_IP + valueFrom: + fieldRef: + fieldPath: status.hostIP +``` + +```bash +helm upgrade bifrost bifrost/bifrost --reuse-values -f datadog-values.yaml +``` + + + + + +### Maxim Observability + +Sends LLM request/response data to [Maxim](https://getmaxim.ai) for tracing, evaluation, and observability. 
+ +| Parameter | Description | Default | +|-----------|-------------|---------| +| `bifrost.plugins.maxim.enabled` | Enable Maxim plugin | `false` | +| `bifrost.plugins.maxim.config.api_key` | Maxim API key (plain text, prefer secret) | `""` | +| `bifrost.plugins.maxim.config.log_repo_id` | Maxim log repository ID | `""` | +| `bifrost.plugins.maxim.secretRef.name` | Kubernetes Secret name for API key | `""` | +| `bifrost.plugins.maxim.secretRef.key` | Key within the secret | `"api-key"` | + +```bash +kubectl create secret generic maxim-credentials \ + --from-literal=api-key='your-maxim-api-key' +``` + +```yaml +# maxim-values.yaml +image: + tag: "v1.4.11" + +bifrost: + plugins: + maxim: + enabled: true + config: + log_repo_id: "your-log-repo-id" + secretRef: + name: "maxim-credentials" + key: "api-key" +``` + +```bash +helm upgrade bifrost bifrost/bifrost --reuse-values -f maxim-values.yaml +``` + + + + + +### Custom / Dynamic Plugins + +Load a custom Go plugin (compiled `.so` file) at runtime. 
+ +```yaml +bifrost: + plugins: + custom: + - name: "my-custom-plugin" + enabled: true + path: "/plugins/my-plugin.so" + version: 1 + config: + api_endpoint: "https://my-service.example.com" + timeout: 5000 +``` + +Mount the `.so` file via a volume: + +```yaml +volumes: + - name: custom-plugins + configMap: + name: bifrost-custom-plugins + +volumeMounts: + - name: custom-plugins + mountPath: /plugins +``` + +Or use an init container to download the plugin binary: + +```yaml +initContainers: + - name: download-plugin + image: curlimages/curl:8.6.0 + command: + - sh + - -c + - | + curl -fsSL https://plugins.example.com/my-plugin.so \ + -o /plugins/my-plugin.so + volumeMounts: + - name: plugin-dir + mountPath: /plugins + +volumes: + - name: plugin-dir + emptyDir: {} + +volumeMounts: + - name: plugin-dir + mountPath: /plugins +``` + +```bash +helm upgrade bifrost bifrost/bifrost --reuse-values -f custom-plugin-values.yaml +``` + + + + + +--- + +## All Plugins Together + +```yaml +# all-plugins-values.yaml +image: + tag: "v1.4.11" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "encryption-key" + + plugins: + telemetry: + enabled: true + config: + custom_labels: + - name: "environment" + value: "production" + + logging: + enabled: true + config: + disable_content_logging: false + logging_headers: + - "x-request-id" + + governance: + enabled: true + config: + is_vk_mandatory: true + + semanticCache: + enabled: true + config: + provider: "openai" + keys: + - value: "env.CACHE_OPENAI_KEY" + weight: 1 + embedding_model: "text-embedding-3-small" + dimension: 1536 + threshold: 0.85 + ttl: "1h" + + otel: + enabled: true + config: + service_name: "bifrost" + collector_url: "otel-collector.observability.svc.cluster.local:4317" + protocol: "grpc" + insecure: true +``` + +```bash +helm install bifrost bifrost/bifrost -f all-plugins-values.yaml +``` diff --git a/docs/deployment-guides/helm/providers.mdx b/docs/deployment-guides/helm/providers.mdx new file 
mode 100644 index 0000000000..8a4e0ccc4c --- /dev/null +++ b/docs/deployment-guides/helm/providers.mdx @@ -0,0 +1,941 @@ +--- +title: "Provider Setup" +description: "Configure LLM providers in the Bifrost Helm chart — API keys, cloud-native auth, and self-hosted endpoints" +icon: "plug" +--- + +All providers are configured under `bifrost.providers` in your values file. Each provider entry contains a `keys` list where each key has a `name`, `value`, `weight`, and optional provider-specific config. + +**Two ways to supply credentials:** + +- **Direct value** — `value: "sk-..."` (fine for dev; avoid in production) +- **Kubernetes Secret + env var** — store the key in a Secret, inject as an env var, and reference it with `value: "env.VAR_NAME"` + +The `providerSecrets` block handles the Secret → env var injection automatically: + +```yaml +bifrost: + providers: + openai: + keys: + - name: "primary" + value: "env.OPENAI_API_KEY" # resolved at runtime + weight: 1 + + providerSecrets: + openai: + existingSecret: "my-openai-secret" + key: "api-key" + envVar: "OPENAI_API_KEY" # injected into the pod +``` + +--- + + + + + +### OpenAI + +Supports multiple keys with weighted load balancing. The key with `use_for_batch_api: true` is eligible for the Batch API. 
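
Key `weight` values are relative shares rather than percentages: each key receives `weight / sum(weights)` of the traffic. For the weights used in the example below (2, 1, 1), the arithmetic works out to the 50/25/25 split noted in the comments:

```bash
# Traffic share per key = weight / sum(weights); here 2, 1, 1 -> 50%, 25%, 25%.
awk 'BEGIN { w1 = 2; w2 = 1; w3 = 1; total = w1 + w2 + w3
  printf "primary=%d%% secondary=%d%% batch=%d%%\n",
    100 * w1 / total, 100 * w2 / total, 100 * w3 / total }'
# prints: primary=50% secondary=25% batch=25%
```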
+ +**Step 1 — Create secret** + +```bash +kubectl create secret generic openai-credentials \ + --from-literal=api-key-1='sk-your-primary-key' \ + --from-literal=api-key-2='sk-your-secondary-key' \ + --from-literal=api-key-batch='sk-your-batch-key' +``` + +**Step 2 — Values file** + +```yaml +# openai-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + openai: + keys: + - name: "openai-primary" + value: "env.OPENAI_KEY_1" + weight: 2 # 50% of traffic + models: ["*"] + - name: "openai-secondary" + value: "env.OPENAI_KEY_2" + weight: 1 # 25% + models: ["gpt-4o-mini"] # restrict to cheaper model + - name: "openai-batch" + value: "env.OPENAI_KEY_BATCH" + weight: 1 # 25% + models: ["*"] + use_for_batch_api: true + + providerSecrets: + openai-key-1: + existingSecret: "openai-credentials" + key: "api-key-1" + envVar: "OPENAI_KEY_1" + openai-key-2: + existingSecret: "openai-credentials" + key: "api-key-2" + envVar: "OPENAI_KEY_2" + openai-key-batch: + existingSecret: "openai-credentials" + key: "api-key-batch" + envVar: "OPENAI_KEY_BATCH" +``` + +**Step 3 — Install** + +```bash +helm install bifrost bifrost/bifrost -f openai-values.yaml +``` + +**Optional — per-provider network config** + +```yaml +bifrost: + providers: + openai: + keys: + - name: "primary" + value: "env.OPENAI_KEY_1" + weight: 1 + network_config: + default_request_timeout_in_seconds: 120 + max_retries: 3 + retry_backoff_initial_ms: 500 + retry_backoff_max_ms: 5000 + max_conns_per_host: 5000 +``` + + + + + +### Anthropic + +```bash +kubectl create secret generic anthropic-credentials \ + --from-literal=api-key-1='sk-ant-your-primary-key' \ + --from-literal=api-key-2='sk-ant-your-secondary-key' +``` + +```yaml +# anthropic-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + anthropic: + keys: + - name: "anthropic-primary" + value: "env.ANTHROPIC_KEY_1" + weight: 1 + models: ["*"] + - name: "anthropic-secondary" + value: "env.ANTHROPIC_KEY_2" + weight: 1 + models: ["*"] + + 
providerSecrets: + anthropic-key-1: + existingSecret: "anthropic-credentials" + key: "api-key-1" + envVar: "ANTHROPIC_KEY_1" + anthropic-key-2: + existingSecret: "anthropic-credentials" + key: "api-key-2" + envVar: "ANTHROPIC_KEY_2" +``` + +```bash +helm install bifrost bifrost/bifrost -f anthropic-values.yaml +``` + +**Override Anthropic beta headers** (optional): + +```yaml +bifrost: + providers: + anthropic: + keys: + - name: "primary" + value: "env.ANTHROPIC_KEY_1" + weight: 1 + network_config: + beta_header_overrides: + redact-thinking-: true +``` + + + + + +### Azure OpenAI + +Azure requires `azure_key_config` on every key with `endpoint`, `api_version`, and a `deployments` map (logical model name → Azure deployment name). + +Two auth modes are supported: + + + + +**Step 1 — Create secret** + +```bash +kubectl create secret generic azure-credentials \ + --from-literal=api-key='your-azure-openai-api-key' \ + --from-literal=endpoint='https://your-resource.openai.azure.com' +``` + +**Step 2 — Values file** + +```yaml +# azure-apikey-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + azure: + keys: + - name: "azure-primary" + value: "env.AZURE_API_KEY" + weight: 1 + models: ["gpt-4o", "gpt-4o-mini", "text-embedding-3-small"] + azure_key_config: + endpoint: "env.AZURE_ENDPOINT" + api_version: "2024-10-21" + deployments: + gpt-4o: "gpt-4o-prod" + gpt-4o-mini: "gpt-4o-mini-prod" + text-embedding-3-small: "embeddings-prod" + + providerSecrets: + azure-api-key: + existingSecret: "azure-credentials" + key: "api-key" + envVar: "AZURE_API_KEY" + azure-endpoint: + existingSecret: "azure-credentials" + key: "endpoint" + envVar: "AZURE_ENDPOINT" +``` + +**Step 3 — Install** + +```bash +helm install bifrost bifrost/bifrost -f azure-apikey-values.yaml +``` + + + + +When `value` is empty, Bifrost uses `DefaultAzureCredential` — which automatically resolves credentials from: +- AKS Workload Identity (recommended for production) +- Azure VM managed identity +- `az 
login` (developer machines) + +**Step 1 — Annotate the service account** (AKS Workload Identity) + +```bash +# Associate the Kubernetes service account with your Azure managed identity +kubectl annotate serviceaccount bifrost \ + azure.workload.identity/client-id="" +``` + +```yaml +serviceAccount: + annotations: + azure.workload.identity/client-id: "" +``` + +**Step 2 — Values file** + +```bash +kubectl create secret generic azure-config \ + --from-literal=endpoint='https://your-resource.openai.azure.com' +``` + +```yaml +# azure-msi-values.yaml +image: + tag: "v1.4.11" + +serviceAccount: + annotations: + azure.workload.identity/client-id: "" + +bifrost: + providers: + azure: + keys: + - name: "azure-workload-identity" + value: "" # empty = DefaultAzureCredential + weight: 1 + models: ["gpt-4o"] + azure_key_config: + endpoint: "env.AZURE_ENDPOINT" + api_version: "2024-10-21" + deployments: + gpt-4o: "gpt-4o-prod" + + providerSecrets: + azure-endpoint: + existingSecret: "azure-config" + key: "endpoint" + envVar: "AZURE_ENDPOINT" +``` + +**Step 3 — Install** + +```bash +helm install bifrost bifrost/bifrost -f azure-msi-values.yaml +``` + + + + +**Multi-region failover** (two deployments, different regions): + +```yaml +bifrost: + providers: + azure: + keys: + - name: "eastus" + value: "env.AZURE_KEY_EAST" + weight: 1 + azure_key_config: + endpoint: "env.AZURE_ENDPOINT_EAST" + api_version: "2024-10-21" + deployments: + gpt-4o: "gpt-4o-eastus" + - name: "westus" + value: "env.AZURE_KEY_WEST" + weight: 1 + azure_key_config: + endpoint: "env.AZURE_ENDPOINT_WEST" + api_version: "2024-10-21" + deployments: + gpt-4o: "gpt-4o-westus" +``` + + + + + +### AWS Bedrock + +Bedrock requires `bedrock_key_config` with at minimum a `region`. 
Three auth modes: + + + + +```bash +kubectl create secret generic aws-credentials \ + --from-literal=access-key-id='AKIAIOSFODNN7EXAMPLE' \ + --from-literal=secret-access-key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' +``` + +```yaml +# bedrock-static-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + bedrock: + keys: + - name: "bedrock-static" + value: "" + weight: 1 + models: ["*"] + bedrock_key_config: + region: "us-east-1" + access_key: "env.AWS_ACCESS_KEY_ID" + secret_key: "env.AWS_SECRET_ACCESS_KEY" + deployments: + # Logical name -> Bedrock inference profile + anthropic.claude-3-5-sonnet: "us.anthropic.claude-3-5-sonnet-20240620-v1:0" + + providerSecrets: + aws-access-key: + existingSecret: "aws-credentials" + key: "access-key-id" + envVar: "AWS_ACCESS_KEY_ID" + aws-secret-key: + existingSecret: "aws-credentials" + key: "secret-access-key" + envVar: "AWS_SECRET_ACCESS_KEY" +``` + +```bash +helm install bifrost bifrost/bifrost -f bedrock-static-values.yaml +``` + + + + +When only `region` is set, Bifrost inherits credentials from the AWS SDK default chain — IRSA (IAM Roles for Service Accounts), EC2 instance profile, or `AWS_*` env vars. 
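
Whichever source in the chain supplies credentials, the resolved role needs permission to call the Bedrock runtime API. A minimal policy sketch; the two action names are the standard Bedrock invoke permissions, and in production you would narrow `Resource` to the specific model or inference-profile ARNs in use:

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "bedrock:InvokeModel",
        "bedrock:InvokeModelWithResponseStream"
      ],
      "Resource": "*"
    }
  ]
}
```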
+ +**Step 1 — Annotate the service account with the IAM role** + +```bash +kubectl annotate serviceaccount bifrost \ + eks.amazonaws.com/role-arn="arn:aws:iam::123456789012:role/BifrostBedrockRole" +``` + +```yaml +serviceAccount: + annotations: + eks.amazonaws.com/role-arn: "arn:aws:iam::123456789012:role/BifrostBedrockRole" +``` + +**Step 2 — Values file** + +```yaml +# bedrock-irsa-values.yaml +image: + tag: "v1.4.11" + +serviceAccount: + annotations: + eks.amazonaws.com/role-arn: "arn:aws:iam::123456789012:role/BifrostBedrockRole" + +bifrost: + providers: + bedrock: + keys: + - name: "bedrock-irsa" + value: "" + weight: 1 + models: ["*"] + bedrock_key_config: + region: "us-east-1" + # No access_key / secret_key — SDK uses IRSA token automatically +``` + +```bash +helm install bifrost bifrost/bifrost -f bedrock-irsa-values.yaml +``` + + + + +Assumes a cross-account role on top of the default credential chain. + +```yaml +# bedrock-assumerole-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + bedrock: + keys: + - name: "bedrock-assumerole" + value: "" + weight: 1 + models: ["*"] + bedrock_key_config: + region: "us-west-2" + # Source identity from pod's default chain, then assume this role + role_arn: "env.AWS_ROLE_ARN" + external_id: "env.AWS_EXTERNAL_ID" + session_name: "bifrost-session" +``` + +```bash +kubectl create secret generic aws-role-config \ + --from-literal=role-arn='arn:aws:iam::999999999999:role/CrossAccountBedrockRole' \ + --from-literal=external-id='your-external-id' +``` + +```yaml + providerSecrets: + aws-role-arn: + existingSecret: "aws-role-config" + key: "role-arn" + envVar: "AWS_ROLE_ARN" + aws-external-id: + existingSecret: "aws-role-config" + key: "external-id" + envVar: "AWS_EXTERNAL_ID" +``` + +```bash +helm install bifrost bifrost/bifrost -f bedrock-assumerole-values.yaml +``` + + + + +**Batch API — S3 configuration** + +```yaml +bedrock_key_config: + region: "us-east-1" + access_key: "env.AWS_ACCESS_KEY_ID" + secret_key: 
"env.AWS_SECRET_ACCESS_KEY" + batch_s3_config: + buckets: + - bucket_name: "my-bedrock-batch-bucket" + prefix: "batch/" + is_default: true +``` + + + + + +### Google Vertex AI + +Vertex requires `vertex_key_config` with `project_id` and `region`. Two auth modes: + + + + +```bash +# Base64-encode the service account JSON +SA_JSON=$(cat service-account-key.json | base64 -w 0) + +kubectl create secret generic gcp-credentials \ + --from-literal=service-account-json="${SA_JSON}" +``` + +```yaml +# vertex-sa-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + vertex: + keys: + - name: "vertex-sa-key" + value: "" + weight: 1 + models: ["*"] + vertex_key_config: + project_id: "env.VERTEX_PROJECT_ID" + region: "us-central1" + auth_credentials: "env.VERTEX_AUTH_CREDENTIALS" + + providerSecrets: + vertex-project-id: + existingSecret: "gcp-credentials" + key: "project-id" + envVar: "VERTEX_PROJECT_ID" + vertex-sa: + existingSecret: "gcp-credentials" + key: "service-account-json" + envVar: "VERTEX_AUTH_CREDENTIALS" +``` + +```bash +helm install bifrost bifrost/bifrost -f vertex-sa-values.yaml +``` + + + + +When `auth_credentials` is omitted, Bifrost calls `google.FindDefaultCredentials` — which resolves to: +- GKE Workload Identity (recommended) +- GCE metadata server (on Compute Engine / Cloud Run) +- `GOOGLE_APPLICATION_CREDENTIALS` path +- `gcloud auth application-default login` (developer machines) + +**Step 1 — Annotate the service account** (GKE Workload Identity) + +```bash +gcloud iam service-accounts add-iam-policy-binding \ + bifrost-sa@my-project.iam.gserviceaccount.com \ + --role roles/iam.workloadIdentityUser \ + --member "serviceAccount:my-project.svc.id.goog[default/bifrost]" +``` + +```yaml +serviceAccount: + annotations: + iam.gke.io/gcp-service-account: "bifrost-sa@my-project.iam.gserviceaccount.com" +``` + +**Step 2 — Values file** + +```yaml +# vertex-wli-values.yaml +image: + tag: "v1.4.11" + +serviceAccount: + annotations: + 
iam.gke.io/gcp-service-account: "bifrost-sa@my-project.iam.gserviceaccount.com" + +bifrost: + providers: + vertex: + keys: + - name: "vertex-workload-identity" + value: "" + weight: 1 + models: ["*"] + vertex_key_config: + project_id: "my-gcp-project" + region: "us-central1" + # auth_credentials intentionally omitted → ADC lookup +``` + +```bash +helm install bifrost bifrost/bifrost -f vertex-wli-values.yaml +``` + + + + + + + + +### Standard API-Key Providers + +These providers follow the same simple pattern — one or more keys with weights. + + + + +```bash +kubectl create secret generic groq-credentials \ + --from-literal=api-key='gsk_your_groq_api_key' +``` + +```yaml +bifrost: + providers: + groq: + keys: + - name: "groq-primary" + value: "env.GROQ_API_KEY" + weight: 1 + models: ["*"] + + providerSecrets: + groq-key: + existingSecret: "groq-credentials" + key: "api-key" + envVar: "GROQ_API_KEY" +``` + + + + +```bash +kubectl create secret generic gemini-credentials \ + --from-literal=api-key='your-gemini-api-key' +``` + +```yaml +bifrost: + providers: + gemini: + keys: + - name: "gemini-main" + value: "env.GEMINI_API_KEY" + weight: 1 + models: ["*"] + + providerSecrets: + gemini-key: + existingSecret: "gemini-credentials" + key: "api-key" + envVar: "GEMINI_API_KEY" +``` + + + + +```bash +kubectl create secret generic mistral-credentials \ + --from-literal=api-key='your-mistral-api-key' +``` + +```yaml +bifrost: + providers: + mistral: + keys: + - name: "mistral-main" + value: "env.MISTRAL_API_KEY" + weight: 1 + models: ["*"] + + providerSecrets: + mistral-key: + existingSecret: "mistral-credentials" + key: "api-key" + envVar: "MISTRAL_API_KEY" +``` + + + + +All standard API-key providers follow the same pattern. 
Replace the provider name and env var name accordingly: + +```yaml +bifrost: + providers: + cohere: + keys: + - name: "cohere-main" + value: "env.COHERE_API_KEY" + weight: 1 + perplexity: + keys: + - name: "perplexity-main" + value: "env.PERPLEXITY_API_KEY" + weight: 1 + xai: + keys: + - name: "xai-main" + value: "env.XAI_API_KEY" + weight: 1 + cerebras: + keys: + - name: "cerebras-main" + value: "env.CEREBRAS_API_KEY" + weight: 1 + openrouter: + keys: + - name: "openrouter-main" + value: "env.OPENROUTER_API_KEY" + weight: 1 + nebius: + keys: + - name: "nebius-main" + value: "env.NEBIUS_API_KEY" + weight: 1 +``` + + + + +**Install command (any of the above)** + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + -f provider-values.yaml +``` + + + + + +### Self-Hosted Providers + +Self-hosted providers point to a URL you operate. No API key is typically required (`value: ""`). + + + + +```yaml +# ollama-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + ollama: + keys: + - name: "ollama-local" + value: "" + weight: 1 + models: ["*"] + ollama_key_config: + url: "http://ollama.default.svc.cluster.local:11434" +``` + +```bash +helm install bifrost bifrost/bifrost -f ollama-values.yaml +``` + +Using an env var for the URL (useful across environments): + +```bash +kubectl create secret generic ollama-config \ + --from-literal=url='http://ollama.default.svc.cluster.local:11434' +``` + +```yaml + ollama_key_config: + url: "env.OLLAMA_URL" + + providerSecrets: + ollama-url: + existingSecret: "ollama-config" + key: "url" + envVar: "OLLAMA_URL" +``` + + + + +vLLM instances are model-specific — one key per served model. 
+ +```yaml +# vllm-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + vllm: + keys: + - name: "vllm-llama3-70b" + value: "" + weight: 1 + models: ["llama-3-70b"] + vllm_key_config: + url: "http://vllm.default.svc.cluster.local:8000" + model_name: "meta-llama/Meta-Llama-3-70B-Instruct" + - name: "vllm-mistral" + value: "" + weight: 1 + models: ["mistral-7b"] + vllm_key_config: + url: "http://vllm-mistral.default.svc.cluster.local:8000" + model_name: "mistralai/Mistral-7B-Instruct-v0.3" +``` + +```bash +helm install bifrost bifrost/bifrost -f vllm-values.yaml +``` + + + + +```yaml +# sgl-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + sgl: + keys: + - name: "sgl-main" + value: "" + weight: 1 + models: ["*"] + sgl_key_config: + url: "http://sgl-router.default.svc.cluster.local:30000" +``` + +```bash +helm install bifrost bifrost/bifrost -f sgl-values.yaml +``` + + + + +These providers use `aliases` to map logical model names to provider-specific IDs. + +```yaml +bifrost: + providers: + huggingface: + keys: + - name: "hf-main" + value: "env.HF_API_KEY" + weight: 1 + models: ["llama-3", "mixtral"] + aliases: + llama-3: "meta-llama/Meta-Llama-3-8B-Instruct" + mixtral: "mistralai/Mixtral-8x7B-Instruct-v0.1" + + replicate: + keys: + - name: "replicate-main" + value: "env.REPLICATE_API_KEY" + weight: 1 + models: ["llama-3"] + aliases: + llama-3: "meta/meta-llama-3-70b-instruct" + replicate_key_config: + use_deployments_endpoint: false +``` + + + + + + + + +--- + +## Multi-Provider Example + +Combine providers in a single values file: + +```yaml +# multi-provider-values.yaml +image: + tag: "v1.4.11" + +bifrost: + providers: + openai: + keys: + - name: "openai-primary" + value: "env.OPENAI_API_KEY" + weight: 2 + models: ["*"] + anthropic: + keys: + - name: "anthropic-primary" + value: "env.ANTHROPIC_API_KEY" + weight: 1 + models: ["*"] + groq: + keys: + - name: "groq-primary" + value: "env.GROQ_API_KEY" + weight: 1 + models: ["*"] + + 
providerSecrets: + openai-key: + existingSecret: "provider-keys" + key: "openai" + envVar: "OPENAI_API_KEY" + anthropic-key: + existingSecret: "provider-keys" + key: "anthropic" + envVar: "ANTHROPIC_API_KEY" + groq-key: + existingSecret: "provider-keys" + key: "groq" + envVar: "GROQ_API_KEY" + + plugins: + logging: + enabled: true + governance: + enabled: true +``` + +```bash +# Create a single secret with all provider keys +kubectl create secret generic provider-keys \ + --from-literal=openai='sk-your-openai-key' \ + --from-literal=anthropic='sk-ant-your-anthropic-key' \ + --from-literal=groq='gsk_your-groq-key' + +helm install bifrost bifrost/bifrost -f multi-provider-values.yaml +``` diff --git a/docs/deployment-guides/helm/storage.mdx b/docs/deployment-guides/helm/storage.mdx new file mode 100644 index 0000000000..244ece3fb2 --- /dev/null +++ b/docs/deployment-guides/helm/storage.mdx @@ -0,0 +1,550 @@ +--- +title: "Storage" +description: "Configure Bifrost storage backends in Helm — SQLite, PostgreSQL (embedded and external), per-store overrides, and S3/GCS object storage for logs" +icon: "database" +--- + +Bifrost persists two types of data — **config** (providers, virtual keys, governance rules) and **logs** (request/response records). Each has its own store, both defaulting to the top-level `storage.mode`. + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `storage.mode` | Default backend for both stores (`sqlite` or `postgres`) | `sqlite` | +| `storage.configStore.type` | Override backend for the config store | `""` (inherits `storage.mode`) | +| `storage.logsStore.type` | Override backend for the logs store | `""` (inherits `storage.mode`) | + + +When any store uses SQLite the chart deploys a **StatefulSet** with a PVC. With PostgreSQL only (no SQLite) it deploys a **Deployment**. Mixing backends (e.g. config=postgres, logs=sqlite) still requires a StatefulSet. 
+ + +--- + + + + + +### SQLite (Default) + +Simplest setup — no external database required. Bifrost runs as a StatefulSet with a persistent volume for the SQLite files. + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `storage.persistence.enabled` | Create a PVC for SQLite data | `true` | +| `storage.persistence.size` | PVC size | `10Gi` | +| `storage.persistence.accessMode` | PVC access mode | `ReadWriteOnce` | +| `storage.persistence.storageClass` | Storage class (leave empty for cluster default) | `""` | +| `storage.persistence.existingClaim` | Reuse an existing PVC | `""` | + +```yaml +# sqlite-values.yaml +image: + tag: "v1.4.11" + +storage: + mode: sqlite + persistence: + enabled: true + size: 20Gi + # storageClass: "gp3" # uncomment to pin storage class + +bifrost: + encryptionKey: "your-32-byte-encryption-key-here" +``` + +```bash +helm install bifrost bifrost/bifrost -f sqlite-values.yaml +``` + +**Reuse an existing PVC** (e.g. after a StatefulSet migration): + +```yaml +storage: + persistence: + existingClaim: "bifrost-data" +``` + + +Upgrading from SQLite to PostgreSQL requires a data migration — the two stores are not compatible. Plan accordingly before switching `storage.mode` on a running deployment. + + +#### StatefulSet Migration (chart v2.0.0+) + +Prior to v2.0.0, SQLite used a Deployment + manual PVC. v2.0.0 moved SQLite to a StatefulSet. If upgrading from an older chart: + +```bash +# 1. Scale down the old deployment +kubectl scale deployment bifrost --replicas=0 + +# 2. Note the existing PVC name +kubectl get pvc + +# 3. Upgrade the chart, pointing at the existing claim +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set storage.persistence.existingClaim= \ + --set image.tag=v1.4.11 +``` + + + + + +### Embedded PostgreSQL + +The chart can deploy a PostgreSQL instance alongside Bifrost. Good for simple production setups where you don't have an existing database. 
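Rather than hand-typing the database password used in the example below, you can generate one locally. A minimal sketch — `openssl` is assumed to be available; any CSPRNG-backed generator works:

```bash
# Generate a strong random password for the embedded PostgreSQL user.
# base64 of 24 random bytes yields a 32-character string with no padding.
PG_PASSWORD=$(openssl rand -base64 24)
echo "${#PG_PASSWORD}"   # 32
```

Store the result in the `postgres-credentials` secret and reference it via `existingSecret` rather than committing a literal password to your values file.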
+ +| Parameter | Description | Default | +|-----------|-------------|---------| +| `storage.mode` | Set to `postgres` | `sqlite` | +| `postgresql.enabled` | Deploy PostgreSQL as a sub-deployment | `false` | +| `postgresql.auth.username` | Database user | `bifrost` | +| `postgresql.auth.password` | Database password | `bifrost_password` | +| `postgresql.auth.database` | Database name | `bifrost` | +| `postgresql.primary.persistence.size` | PVC size for PostgreSQL data | `8Gi` | + + +Ensure the database is created with **UTF8 encoding**. The embedded PostgreSQL deployment handles this automatically. See [PostgreSQL UTF8 Requirement](/quickstart/gateway/setting-up#postgresql-utf8-requirement) for manual setups. + + +```bash +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-secure-postgres-password' +``` + +```yaml +# embedded-postgres-values.yaml +image: + tag: "v1.4.11" + +storage: + mode: postgres + +postgresql: + enabled: true + auth: + username: bifrost + password: "your-secure-postgres-password" # use existingSecret in production + database: bifrost + primary: + persistence: + enabled: true + size: 50Gi + resources: + requests: + cpu: 500m + memory: 1Gi + limits: + cpu: 2000m + memory: 4Gi + +bifrost: + encryptionKey: "your-32-byte-encryption-key-here" +``` + +```bash +helm install bifrost bifrost/bifrost -f embedded-postgres-values.yaml +``` + +**Verify the connection from Bifrost:** + +```bash +kubectl exec -it deployment/bifrost -- nc -zv bifrost-postgresql 5432 +``` + + + + + +### External PostgreSQL + +Point Bifrost at an existing PostgreSQL instance — RDS, Cloud SQL, Azure Database, or self-managed. 
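If you provision the database yourself, the UTF8 requirement applies here too. A sketch of the statement to run as an admin user before installing — host and admin credentials are placeholders, and `TEMPLATE template0` is used so the encoding is not inherited from a non-UTF8 template database:

```bash
# SQL to create the Bifrost database with UTF8 encoding.
# Run it against your instance, for example:
#   psql "host=your-db-host user=postgres sslmode=require" -c "$CREATE_SQL"
CREATE_SQL="CREATE DATABASE bifrost OWNER bifrost ENCODING 'UTF8' TEMPLATE template0;"
echo "$CREATE_SQL"
```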
+ +| Parameter | Description | Default | +|-----------|-------------|---------| +| `postgresql.enabled` | Must be `false` | `false` | +| `postgresql.external.enabled` | Enable external connection | `false` | +| `postgresql.external.host` | Hostname or IP | `""` | +| `postgresql.external.port` | Port | `5432` | +| `postgresql.external.user` | Username | `bifrost` | +| `postgresql.external.database` | Database name | `bifrost` | +| `postgresql.external.sslMode` | SSL mode (`disable`, `require`, `verify-ca`, `verify-full`) | `disable` | +| `postgresql.external.existingSecret` | Secret name for the password | `""` | +| `postgresql.external.passwordKey` | Key within the secret | `"password"` | + +```bash +kubectl create secret generic external-postgres-credentials \ + --from-literal=password='your-external-postgres-password' +``` + +```yaml +# external-postgres-values.yaml +image: + tag: "v1.4.11" + +storage: + mode: postgres + +postgresql: + enabled: false + external: + enabled: true + host: "your-rds-endpoint.us-east-1.rds.amazonaws.com" + port: 5432 + user: bifrost + database: bifrost + sslMode: require + existingSecret: "external-postgres-credentials" + passwordKey: "password" + +bifrost: + encryptionKey: "your-32-byte-encryption-key-here" +``` + +```bash +helm install bifrost bifrost/bifrost -f external-postgres-values.yaml +``` + +**Test connectivity before installing:** + +```bash +kubectl run pg-test --image=postgres:16-alpine --rm -it --restart=Never -- \ + psql "host=your-rds-endpoint.us-east-1.rds.amazonaws.com dbname=bifrost user=bifrost sslmode=require" \ + -c "SELECT version();" +``` + + + + + +### Mixed Backend + +Run the config store on PostgreSQL (fast lookups, shared across replicas) while keeping logs on SQLite (simpler, cheaper for append-heavy workloads). 
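One sizing check worth doing with any PostgreSQL-backed store: each replica holds its own connection pools, so the server must accommodate the worst case across all replicas. A rough sketch — the pool sizes here mirror the connection-pool tuning example later on this page and are illustrative assumptions, not chart defaults:

```bash
# Rough ceiling on server-side connections Bifrost can open:
#   replicas x (configStore maxOpenConns + logsStore maxOpenConns)
REPLICAS=3; CONFIG_MAX_OPEN=50; LOGS_MAX_OPEN=100
echo $(( REPLICAS * (CONFIG_MAX_OPEN + LOGS_MAX_OPEN) ))   # 450
```

Keep PostgreSQL's `max_connections` comfortably above this number, leaving headroom for admin and monitoring sessions.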
+ +```yaml +# mixed-values.yaml +image: + tag: "v1.4.11" + +storage: + mode: sqlite # default fallback + configStore: + type: postgres # override: config uses postgres + logsStore: + type: sqlite # explicit: logs use sqlite + persistence: + enabled: true + size: 20Gi # for the SQLite logs store + +postgresql: + external: + enabled: true + host: "your-postgres-host.example.com" + port: 5432 + user: bifrost + database: bifrost + sslMode: require + existingSecret: "postgres-credentials" + passwordKey: "password" + +bifrost: + encryptionKey: "your-32-byte-encryption-key-here" +``` + +```bash +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-postgres-password' + +helm install bifrost bifrost/bifrost -f mixed-values.yaml +``` + + +In mixed mode, Bifrost deploys a StatefulSet (because SQLite is in use) with both a PostgreSQL connection and a local PVC for the SQLite log store. + + +**PostgreSQL connection pool tuning** (high log volume): + +```yaml +storage: + configStore: + type: postgres + maxIdleConns: 5 + maxOpenConns: 50 + logsStore: + type: postgres + maxIdleConns: 10 + maxOpenConns: 100 +``` + + + + + +--- + +## Object Storage for Logs + +Offload large request/response payloads from the database to S3 or GCS. The DB retains only lightweight index records; payloads are fetched on demand. 
+ + + + +```bash +kubectl create secret generic s3-credentials \ + --from-literal=access-key-id='AKIAIOSFODNN7EXAMPLE' \ + --from-literal=secret-access-key='wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' +``` + +```yaml +storage: + logsStore: + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + prefix: "bifrost" + compress: true # gzip compression + + # S3 configuration + region: us-east-1 + accessKeyId: "env.S3_ACCESS_KEY_ID" + secretAccessKey: "env.S3_SECRET_ACCESS_KEY" + # endpoint: "" # Custom endpoint for MinIO / Cloudflare R2 + # forcePathStyle: false # Set true for MinIO + +bifrost: + # inject S3 credentials as env vars + providerSecrets: + s3-access-key: + existingSecret: "s3-credentials" + key: "access-key-id" + envVar: "S3_ACCESS_KEY_ID" + s3-secret-key: + existingSecret: "s3-credentials" + key: "secret-access-key" + envVar: "S3_SECRET_ACCESS_KEY" +``` + +**Using IAM role (IRSA / instance profile) instead of static keys:** + +```yaml +storage: + logsStore: + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + region: us-east-1 + # No accessKeyId / secretAccessKey — uses SDK default chain + roleArn: "arn:aws:iam::123456789012:role/BifrostS3Role" +``` + + + + +```bash +kubectl create secret generic gcs-credentials \ + --from-literal=service-account-json="$(cat service-account-key.json)" +``` + +```yaml +storage: + logsStore: + objectStorage: + enabled: true + type: gcs + bucket: "bifrost-logs" + prefix: "bifrost" + compress: true + + # GCS configuration + projectId: "my-gcp-project" + credentialsJson: "env.GCS_CREDENTIALS_JSON" # omit for Workload Identity + +bifrost: + providerSecrets: + gcs-creds: + existingSecret: "gcs-credentials" + key: "service-account-json" + envVar: "GCS_CREDENTIALS_JSON" +``` + + + + +```yaml +storage: + logsStore: + objectStorage: + enabled: true + type: s3 + bucket: "bifrost-logs" + prefix: "bifrost" + compress: false + + region: us-east-1 # can be any value for MinIO + endpoint: 
"http://minio.minio-ns.svc.cluster.local:9000" + accessKeyId: "env.MINIO_ACCESS_KEY" + secretAccessKey: "env.MINIO_SECRET_KEY" + forcePathStyle: true # required for MinIO +``` + + + + +```bash +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + -f object-storage-values.yaml +``` + +--- + +## Vector Store + +A vector store is required for [semantic caching](/deployment-guides/helm/plugins). Choose from Weaviate, Redis, or Qdrant (embedded or external), or Pinecone (external only). + + + + +```yaml +vectorStore: + enabled: true + type: weaviate + weaviate: + enabled: true # deploy embedded Weaviate + replicas: 1 + persistence: + enabled: true + size: 20Gi + resources: + requests: + cpu: 500m + memory: 1Gi + limits: + cpu: 2000m + memory: 4Gi +``` + +**External Weaviate:** + +```yaml +vectorStore: + enabled: true + type: weaviate + weaviate: + enabled: false + external: + enabled: true + scheme: https + host: "weaviate.example.com" + apiKey: "env.WEAVIATE_API_KEY" + grpcHost: "weaviate-grpc.example.com" + grpcSecured: true + existingSecret: "weaviate-credentials" + apiKeyKey: "api-key" +``` + + + + +```yaml +vectorStore: + enabled: true + type: redis + redis: + enabled: true # deploy embedded Redis + auth: + enabled: true + password: "redis_password" + master: + persistence: + size: 8Gi +``` + +**External Redis / AWS MemoryDB:** + +```bash +kubectl create secret generic redis-credentials \ + --from-literal=password='your-redis-password' +``` + +```yaml +vectorStore: + enabled: true + type: redis + redis: + enabled: false + external: + enabled: true + host: "your-redis.cache.amazonaws.com" + port: 6379 + useTls: true + clusterMode: true # required for AWS MemoryDB + existingSecret: "redis-credentials" + passwordKey: "password" +``` + + + + +```yaml +vectorStore: + enabled: true + type: qdrant + qdrant: + enabled: true # deploy embedded Qdrant + persistence: + size: 10Gi +``` + +**External Qdrant:** + +```bash +kubectl create secret generic qdrant-credentials \ 
+ --from-literal=api-key='your-qdrant-api-key' +``` + +```yaml +vectorStore: + enabled: true + type: qdrant + qdrant: + enabled: false + external: + enabled: true + host: "qdrant.example.com" + port: 6334 + useTls: true + existingSecret: "qdrant-credentials" + apiKeyKey: "api-key" +``` + + + + +Pinecone is external-only. + +```bash +kubectl create secret generic pinecone-credentials \ + --from-literal=api-key='your-pinecone-api-key' +``` + +```yaml +vectorStore: + enabled: true + type: pinecone + pinecone: + external: + enabled: true + indexHost: "your-index.svc.us-east1-gcp.pinecone.io" + existingSecret: "pinecone-credentials" + apiKeyKey: "api-key" +``` + + + + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + -f storage-values.yaml +``` diff --git a/docs/deployment-guides/helm/troubleshooting.mdx b/docs/deployment-guides/helm/troubleshooting.mdx new file mode 100644 index 0000000000..1a46d0219d --- /dev/null +++ b/docs/deployment-guides/helm/troubleshooting.mdx @@ -0,0 +1,401 @@ +--- +title: "Troubleshooting" +description: "Diagnose and fix common issues with Bifrost Helm deployments — pods, database, ingress, secrets, PVCs, and performance" +icon: "wrench" +--- + +This page covers the most common problems encountered when deploying Bifrost with Helm, along with diagnostic commands and fixes. 
+ +--- + +## Pod Not Starting + +### Quick diagnostics + +```bash +# Show pod status +kubectl get pods -l app.kubernetes.io/name=bifrost + +# Show pod events (most useful first step) +kubectl describe pod -l app.kubernetes.io/name=bifrost + +# Show pod logs (use --previous if the pod has already crashed) +kubectl logs -l app.kubernetes.io/name=bifrost +kubectl logs -l app.kubernetes.io/name=bifrost --previous +``` + +### Image pull errors (`ErrImagePull` / `ImagePullBackOff`) + +```bash +# Check which image is being pulled +kubectl describe pod -l app.kubernetes.io/name=bifrost | grep "Image:" + +# Verify imagePullSecrets are attached +kubectl get pod -l app.kubernetes.io/name=bifrost -o jsonpath='{.items[0].spec.imagePullSecrets}' + +# Test secret manually +kubectl get secret -o jsonpath='{.data.\.dockerconfigjson}' | base64 -d | jq . +``` + +Common causes: +- `image.tag` not set — the chart requires it; the pod will not start without it +- Pull secret missing or expired (ECR tokens expire after 12 hours) +- Incorrect `image.repository` for enterprise registry + +```bash +# Fix: set the correct tag +helm upgrade bifrost bifrost/bifrost --reuse-values --set image.tag=v1.4.11 +``` + +### PVC not binding (`Pending`) + +```bash +# Check PVC status +kubectl get pvc -l app.kubernetes.io/instance=bifrost + +# Show binding events +kubectl describe pvc -l app.kubernetes.io/instance=bifrost +``` + +Common causes: +- No Persistent Volume provisioner in the cluster +- `storageClass` set to a class that doesn't exist +- `ReadWriteOnce` access mode with multiple replicas (SQLite PVCs are single-node) + +```bash +# List available storage classes +kubectl get storageclass + +# Fix: pin to a valid storage class +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set storage.persistence.storageClass=standard +``` + +### ConfigMap / Secret errors + +```bash +# View the generated ConfigMap (contains rendered config.json) +kubectl get configmap bifrost-config -o yaml + +# 
View secrets the pod depends on +kubectl get secret -l app.kubernetes.io/instance=bifrost + +# Decode a specific secret value +kubectl get secret bifrost-encryption -o jsonpath='{.data.key}' | base64 -d +``` + +### CrashLoopBackOff + +```bash +# Get last log lines before the crash +kubectl logs -l app.kubernetes.io/name=bifrost --previous --tail=50 + +# Common causes shown in logs: +# "encryption key is required" → bifrost.encryptionKey or encryptionKeySecret not set +# "failed to connect to database" → see Database section below +# "image.tag is required" → set image.tag in values +``` + +--- + +## Database Connection Issues + +### Embedded PostgreSQL + +```bash +# Check if the PostgreSQL pod is running +kubectl get pods -l app.kubernetes.io/name=bifrost-postgresql + +# Connect directly to inspect the database +kubectl exec -it deployment/bifrost-postgresql -- psql -U bifrost -d bifrost + +# Test connectivity from the Bifrost pod +kubectl exec -it deployment/bifrost -- nc -zv bifrost-postgresql 5432 + +# Check PostgreSQL logs +kubectl logs deployment/bifrost-postgresql --tail=50 +``` + +### External PostgreSQL + +```bash +# Test connectivity from within the cluster +kubectl run pg-test --image=postgres:16-alpine --rm -it --restart=Never -- \ + psql "host=your-db-host dbname=bifrost user=bifrost sslmode=require" + +# Verify the secret value is correct +kubectl get secret postgres-credentials -o jsonpath='{.data.password}' | base64 -d + +# Check that the external host/port is reachable +kubectl exec -it deployment/bifrost -- nc -zv your-db-host 5432 +``` + +Common causes: +- `sslMode: disable` when the database requires SSL — set `sslMode: require` +- Password in secret doesn't match the database user +- Network policy blocking pod → database traffic +- Database not UTF8 encoded (see [PostgreSQL UTF8 Requirement](/quickstart/gateway/setting-up#postgresql-utf8-requirement)) + +```bash +# Fix: update the secret and restart +kubectl create secret generic 
postgres-credentials \ + --from-literal=password='correct-password' \ + --dry-run=client -o yaml | kubectl apply -f - + +kubectl rollout restart deployment/bifrost +``` + +--- + +## Ingress Not Working + +```bash +# Check ingress resource status +kubectl describe ingress bifrost + +# Check if the ingress controller is running +kubectl get pods -n ingress-nginx -l app.kubernetes.io/name=ingress-nginx + +# View ingress controller logs for routing errors +kubectl logs -n ingress-nginx -l app.kubernetes.io/name=ingress-nginx --tail=50 + +# Verify DNS resolves to the correct load balancer IP +nslookup bifrost.yourdomain.com +kubectl get ingress bifrost -o jsonpath='{.status.loadBalancer.ingress[0].ip}' + +# Test without TLS first +curl -v http://bifrost.yourdomain.com/health +``` + +Common causes: +- `ingress.className` not set or set to a class not installed in the cluster +- TLS certificate not issued yet (cert-manager can take up to 60 seconds) +- Service port mismatch — Bifrost listens on `8080` by default + +```bash +# Check cert-manager certificate status +kubectl get certificate -l app.kubernetes.io/instance=bifrost +kubectl describe certificate bifrost-tls +``` + +--- + +## Secret and Credential Issues + +### Provider API key not resolving + +If Bifrost logs show `env.OPENAI_API_KEY: not set` or similar: + +```bash +# Check the env var is present in the running pod +kubectl exec -it deployment/bifrost -- env | grep OPENAI + +# Verify the providerSecrets secret exists with the right key +kubectl get secret provider-api-keys -o yaml + +# Check the providerSecrets configuration rendered correctly +kubectl get configmap bifrost-config -o yaml | grep -A5 providers +``` + +### Encryption key issues + +```bash +# Verify the secret exists and contains the right key name +kubectl get secret bifrost-encryption -o yaml + +# Check the exact key name matches encryptionKeySecret.key in values +# Default key name is "encryption-key" — if you used "key", set: +# 
bifrost.encryptionKeySecret.key: "key" +``` + +--- + +## High Memory Usage + +```bash +# Check current resource usage +kubectl top pods -l app.kubernetes.io/name=bifrost + +# Check if OOM kills are happening +kubectl describe pod -l app.kubernetes.io/name=bifrost | grep -A3 "OOMKilled\|Limits" + +# View resource requests/limits on running pods +kubectl get pod -l app.kubernetes.io/name=bifrost \ + -o jsonpath='{range .items[*]}{.metadata.name}{"\t"}{.spec.containers[0].resources}{"\n"}{end}' +``` + +**Increase resource limits:** + +```bash +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set resources.limits.memory=4Gi \ + --set resources.requests.memory=1Gi +``` + +**Tune Go runtime** (see [Docker Tuning](/deployment-guides/docker-tuning)): + +```yaml +env: + - name: GOGC + value: "200" # run GC less often + - name: GOMEMLIMIT + value: "3500MiB" # hard memory ceiling slightly below the container limit +``` + +--- + +## High CPU Usage / Latency + +```bash +# Check CPU usage +kubectl top pods -l app.kubernetes.io/name=bifrost + +# Check if HPA is scaling correctly +kubectl get hpa bifrost +kubectl describe hpa bifrost +``` + +Common causes: +- `initialPoolSize` too small — goroutines queuing up; increase to `500`–`1000` +- `dropExcessRequests: false` with a small pool — queue depth growing unboundedly + +```bash +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set bifrost.client.initialPoolSize=1000 \ + --set bifrost.client.dropExcessRequests=true +``` + +--- + +## Autoscaling Issues + +### HPA not scaling + +```bash +# Check HPA status and current metrics +kubectl describe hpa bifrost + +# Verify metrics server is installed +kubectl top nodes +kubectl top pods + +# Common fix: metrics server not installed +# Install with: +kubectl apply -f https://github.com/kubernetes-sigs/metrics-server/releases/latest/download/components.yaml +``` + +### Pods scaling down too aggressively (drops active SSE streams) + +The default 
`scaleDown.stabilizationWindowSeconds: 300` and `preStop` sleep of 15 seconds should prevent this. If streams are still being cut: + +```yaml +terminationGracePeriodSeconds: 120 # increase if streams run longer than 105s + +autoscaling: + behavior: + scaleDown: + stabilizationWindowSeconds: 600 # wait 10 min before scaling down + policies: + - type: Pods + value: 1 + periodSeconds: 300 # remove at most 1 pod per 5 min + +lifecycle: + preStop: + exec: + command: ["sh", "-c", "sleep 30"] # give load balancer more time to drain +``` + +```bash +helm upgrade bifrost bifrost/bifrost --reuse-values -f graceful-shutdown-values.yaml +``` + +--- + +## SQLite / PVC Issues + +### StatefulSet migration (upgrading from chart < v2.0.0) + +Older chart versions used a Deployment + manual PVC. v2.0.0 moved SQLite to a StatefulSet. If upgrading: + +```bash +# 1. Scale down the old deployment +kubectl scale deployment bifrost --replicas=0 + +# 2. Note the existing PVC name +kubectl get pvc + +# 3. Upgrade, pointing at the existing claim +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set storage.persistence.existingClaim= \ + --set image.tag=v1.4.11 +``` + +### Data lost after upgrade + +```bash +# Check if PVCs still exist (they persist after helm uninstall) +kubectl get pvc -l app.kubernetes.io/instance=bifrost + +# Re-attach by setting existingClaim +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set storage.persistence.existingClaim= +``` + +--- + +## Cluster Mode Issues + +### Peers not discovering each other + +```bash +# Check gossip port is reachable between pods +kubectl exec -it bifrost-0 -- nc -zv bifrost-1.bifrost-headless 7946 + +# View gossip-related log lines +kubectl logs -l app.kubernetes.io/name=bifrost --tail=100 | grep -i gossip + +# Check the headless service exists +kubectl get svc bifrost-headless +``` + +For Kubernetes-based discovery, verify the service account has pod list permissions: + +```bash +kubectl auth can-i list pods 
--as=system:serviceaccount:default:bifrost +``` + +--- + +## Useful Diagnostic Commands + +```bash +# Full state dump for a support ticket +kubectl get all -l app.kubernetes.io/instance=bifrost +kubectl describe pod -l app.kubernetes.io/name=bifrost > pod-describe.txt +kubectl logs -l app.kubernetes.io/name=bifrost --tail=200 > pod-logs.txt + +# View the full rendered config.json +kubectl get configmap bifrost-config -o jsonpath='{.data.config\.json}' | jq . + +# Check current Helm values (shows all overrides) +helm get values bifrost + +# Check Helm release status +helm status bifrost + +# View Helm release history +helm history bifrost +``` + +--- + +## Still Stuck? + +- [GitHub Issues](https://github.com/maximhq/bifrost/issues) — search existing issues or open a new one +- [Enterprise Support](mailto:support@getmaxim.ai) — for enterprise customers with SLA diff --git a/docs/deployment-guides/helm/values.mdx b/docs/deployment-guides/helm/values.mdx new file mode 100644 index 0000000000..3161b206fb --- /dev/null +++ b/docs/deployment-guides/helm/values.mdx @@ -0,0 +1,718 @@ +--- +title: "Values Reference" +description: "Complete reference for Bifrost Helm chart values — key parameters, how to supply them, and links to example files" +icon: "sliders" +--- + +This page covers every top-level parameter group in the Bifrost Helm chart's `values.yaml`, how to supply values via `--set` vs `-f`, and where to find ready-made example files. + + +The full values schema is available at [https://getbifrost.ai/schema](https://getbifrost.ai/schema). All `values.yaml` fields map directly to `config.json` fields generated by the chart. 
+ + +## Supplying Values + +### One-liner with `--set` + +Good for a single field or quick experiments: + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set replicaCount=3 \ + --set bifrost.client.initialPoolSize=500 +``` + +### Values file with `-f` + +Recommended for anything beyond a couple of fields: + +```bash +# Create your values file +cat > my-values.yaml <<'EOF' +image: + tag: "v1.4.11" + +replicaCount: 2 + +bifrost: + encryptionKey: "your-32-byte-encryption-key-here" + client: + initialPoolSize: 500 + enableLogging: true +EOF + +# Install +helm install bifrost bifrost/bifrost -f my-values.yaml + +# Upgrade later +helm upgrade bifrost bifrost/bifrost -f my-values.yaml + +# Upgrade and reuse all previously set values, overriding only one field +helm upgrade bifrost bifrost/bifrost \ + --reuse-values \ + --set replicaCount=5 +``` + +### Multiple values files + +Later files override earlier ones — useful for a base + environment-specific overlay: + +```bash +helm install bifrost bifrost/bifrost \ + -f base-values.yaml \ + -f production-overrides.yaml +``` + +--- + +## Key Parameters Reference + +### Image + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `image.repository` | Container image repository | `docker.io/maximhq/bifrost` | +| `image.tag` | **Required.** Image version (e.g. 
`v1.4.11`) | `""` | +| `image.pullPolicy` | Image pull policy | `IfNotPresent` | +| `imagePullSecrets` | List of pull secret names for private registries | `[]` | + +```bash +# Always specify the tag — the chart will not start without it +helm install bifrost bifrost/bifrost --set image.tag=v1.4.11 +``` + +### Replicas & Autoscaling + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `replicaCount` | Static replica count (ignored when HPA is enabled) | `1` | +| `autoscaling.enabled` | Enable Horizontal Pod Autoscaler | `false` | +| `autoscaling.minReplicas` | Minimum replicas | `1` | +| `autoscaling.maxReplicas` | Maximum replicas | `10` | +| `autoscaling.targetCPUUtilizationPercentage` | CPU target for scaling | `80` | +| `autoscaling.targetMemoryUtilizationPercentage` | Memory target for scaling | `80` | +| `autoscaling.behavior.scaleDown.stabilizationWindowSeconds` | Cooldown before scale-down (important for SSE streams) | `300` | +| `autoscaling.behavior.scaleDown.policies[0].value` | Max pods removed per period | `1` | + +### Resources + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `resources.requests.cpu` | CPU request | `500m` | +| `resources.requests.memory` | Memory request | `512Mi` | +| `resources.limits.cpu` | CPU limit | `2000m` | +| `resources.limits.memory` | Memory limit | `2Gi` | + +### Service + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `service.type` | `ClusterIP`, `LoadBalancer`, or `NodePort` | `ClusterIP` | +| `service.port` | Service port | `8080` | + +### Ingress + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `ingress.enabled` | Enable ingress | `false` | +| `ingress.className` | Ingress class (e.g. 
`nginx`, `traefik`) | `""` | +| `ingress.annotations` | Ingress annotations | `{}` | +| `ingress.hosts` | Host rules | see values.yaml | +| `ingress.tls` | TLS configuration | `[]` | + +```yaml +ingress: + enabled: true + className: nginx + annotations: + cert-manager.io/cluster-issuer: letsencrypt-prod + nginx.ingress.kubernetes.io/proxy-body-size: "100m" + hosts: + - host: bifrost.yourdomain.com + paths: + - path: / + pathType: Prefix + tls: + - secretName: bifrost-tls + hosts: + - bifrost.yourdomain.com +``` + +### Probes + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `livenessProbe.initialDelaySeconds` | Seconds before first liveness check | `30` | +| `livenessProbe.periodSeconds` | Liveness check interval | `30` | +| `readinessProbe.initialDelaySeconds` | Seconds before first readiness check | `10` | +| `readinessProbe.periodSeconds` | Readiness check interval | `10` | + +Both probes hit `GET /health`. + +### Graceful Shutdown + +Bifrost supports long-lived SSE streaming connections. The default `preStop` hook and termination grace period let in-flight streams finish before the pod is killed: + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `terminationGracePeriodSeconds` | Total grace period | `60` | +| `lifecycle.preStop.exec.command` | Sleep before SIGTERM so load balancer drains | `["sh", "-c", "sleep 15"]` | + +Increase `terminationGracePeriodSeconds` if your typical stream responses take longer than 45 seconds. + +### Service Account + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `serviceAccount.create` | Create a dedicated service account | `true` | +| `serviceAccount.annotations` | Annotations (e.g. 
for IRSA, Workload Identity) | `{}` | +| `serviceAccount.name` | Override the generated name | `""` | + +### Pod Scheduling + +```yaml +# Spread replicas across nodes +affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchLabels: + app.kubernetes.io/name: bifrost + topologyKey: kubernetes.io/hostname + +# Pin to specific node pool +nodeSelector: + node-type: ai-workload + +# Tolerate GPU taints +tolerations: + - key: "gpu" + operator: "Equal" + value: "true" + effect: "NoSchedule" +``` + +### Extra Environment Variables + +Three ways to inject env vars: + +```yaml +# Inline key/value pairs +env: + - name: HTTP_PROXY + value: "http://proxy.corp.example.com:3128" + +# Map syntax (appended after env) +extraEnv: + NO_PROXY: "169.254.169.254,10.0.0.0/8" + +# Bulk-load from existing Secrets or ConfigMaps +envFrom: + - secretRef: + name: my-corp-secrets + - configMapRef: + name: my-app-config +``` + +### Init Containers + +```yaml +initContainers: + - name: wait-for-db + image: busybox:1.35 + command: ["sh", "-c", "until nc -z postgres-svc 5432; do sleep 2; done"] +``` + +--- + +## Values Examples + +The chart ships ready-made example files under [`helm-charts/bifrost/values-examples/`](https://github.com/maximhq/bifrost/tree/main/helm-charts/bifrost/values-examples): + +| File | Use case | +|------|----------| +| `sqlite-only.yaml` | Minimal local/dev setup | +| `postgres-only.yaml` | Single-store Postgres | +| `production-ha.yaml` | HA: 3 replicas, Postgres, Weaviate, HPA, Ingress | +| `providers-and-virtual-keys.yaml` | All 23 providers + 7 virtual key patterns | +| `secrets-from-k8s.yaml` | All sensitive values from Kubernetes Secrets | +| `external-postgres.yaml` | Point at an existing Postgres instance | +| `postgres-redis.yaml` | Postgres + Redis vector store | +| `postgres-weaviate.yaml` | Postgres + Weaviate vector store | +| `postgres-qdrant.yaml` | Postgres + Qdrant vector store | +| 
`semantic-cache-secret-example.yaml` | Semantic cache with secret injection | +| `mixed-backend.yaml` | Config store = postgres, logs store = sqlite | + +Install from an example file directly: + +```bash +helm install bifrost bifrost/bifrost \ + -f https://raw.githubusercontent.com/maximhq/bifrost/main/helm-charts/bifrost/values-examples/production-ha.yaml \ + --set image.tag=v1.4.11 +``` + +--- + +## Helm Operations + +### View current values + +```bash +helm get values bifrost +``` + +### Diff before upgrading (requires helm-diff plugin) + +```bash +helm diff upgrade bifrost bifrost/bifrost -f my-values.yaml +``` + +### Rollback + +```bash +helm history bifrost +helm rollback bifrost # to previous revision +helm rollback bifrost 2 # to revision 2 +``` + +### Uninstall + +```bash +helm uninstall bifrost + +# Also remove PVCs (deletes all data) +kubectl delete pvc -l app.kubernetes.io/instance=bifrost +``` + +--- + +## All Key Parameters + +A quick-reference table of the most commonly used top-level parameters: + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `image.tag` | **Required.** Bifrost image version (e.g., `v1.4.11`) | `""` | +| `replicaCount` | Number of replicas | `1` | +| `storage.mode` | Storage backend (`sqlite` or `postgres`) | `sqlite` | +| `storage.persistence.size` | PVC size for SQLite | `10Gi` | +| `postgresql.enabled` | Deploy embedded PostgreSQL | `false` | +| `vectorStore.enabled` | Enable vector store | `false` | +| `vectorStore.type` | Vector store type (`weaviate`, `redis`, `qdrant`) | `none` | +| `bifrost.encryptionKey` | Encryption key (use `encryptionKeySecret` in production) | `""` | +| `ingress.enabled` | Enable ingress | `false` | +| `autoscaling.enabled` | Enable HPA | `false` | + +### Secret Reference Parameters + +Use existing Kubernetes Secrets instead of plain-text values. 
Every sensitive field in the chart has a corresponding `existingSecret` / `secretRef` alternative:
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `bifrost.encryptionKeySecret.name` | Secret name for encryption key | `""` |
+| `bifrost.encryptionKeySecret.key` | Key within the secret | `"encryption-key"` |
+| `postgresql.external.existingSecret` | Secret name for PostgreSQL password | `""` |
+| `postgresql.external.passwordKey` | Key within the secret | `"password"` |
+| `vectorStore.redis.external.existingSecret` | Secret name for Redis password | `""` |
+| `vectorStore.redis.external.passwordKey` | Key within the secret | `"password"` |
+| `vectorStore.weaviate.external.existingSecret` | Secret name for Weaviate API key | `""` |
+| `vectorStore.weaviate.external.apiKeyKey` | Key within the secret | `"api-key"` |
+| `vectorStore.qdrant.external.existingSecret` | Secret name for Qdrant API key | `""` |
+| `vectorStore.qdrant.external.apiKeyKey` | Key within the secret | `"api-key"` |
+| `bifrost.plugins.maxim.secretRef.name` | Secret name for Maxim API key | `""` |
+| `bifrost.plugins.maxim.secretRef.key` | Key within the secret | `"api-key"` |
+| `bifrost.providerSecrets.<provider>.existingSecret` | Secret name for provider API key | `""` |
+| `bifrost.providerSecrets.<provider>.key` | Key within the secret | `"api-key"` |
+| `bifrost.providerSecrets.<provider>.envVar` | Environment variable name to inject | `""` |
+
+---
+
+## Advanced Configuration
+
+### Comprehensive Example
+
+A production-ready values file combining the most common settings:
+
+```yaml
+# my-values.yaml
+image:
+  tag: "v1.4.11"
+
+replicaCount: 3
+
+storage:
+  mode: postgres
+
+postgresql:
+  enabled: true
+  auth:
+    password: "secure-password" # use existingSecret in production
+
+autoscaling:
+  enabled: true
+  minReplicas: 3
+  maxReplicas: 10
+
+ingress:
+  enabled: true
+  className: nginx
+  hosts:
+    - host: bifrost.example.com
+      paths:
+        - path: /
+          pathType: Prefix
+
+bifrost:
+  
encryptionKeySecret: + name: "bifrost-encryption" + key: "key" + providers: + openai: + keys: + - name: "primary" + value: "env.OPENAI_API_KEY" + weight: 1 + providerSecrets: + openai: + existingSecret: "provider-api-keys" + key: "openai-api-key" + envVar: "OPENAI_API_KEY" +``` + +```bash +helm install bifrost bifrost/bifrost -f my-values.yaml +``` + +### Node Affinity & Scheduling + +Deploy to specific nodes and spread replicas across hosts: + +```yaml +nodeSelector: + node-type: ai-workload + +affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchLabels: + app.kubernetes.io/name: bifrost + topologyKey: kubernetes.io/hostname + +tolerations: + - key: "gpu" + operator: "Equal" + value: "true" + effect: "NoSchedule" +``` + +### Deployment & Pod Annotations + +Useful for tooling like [Keel](https://keel.sh) for automatic image updates or Datadog APM injection: + +```yaml +deploymentAnnotations: + keel.sh/policy: force + keel.sh/trigger: poll + +podAnnotations: + ad.datadoghq.com/bifrost.logs: '[{"source":"bifrost","service":"bifrost"}]' +``` + +--- + +## Common Patterns + +Ready-made values files for the most common deployment scenarios. Each pattern builds on the [quickstart](/deployment-guides/helm). + + + + +Simple setup for local testing. SQLite, single replica, no autoscaling. + +```bash +helm install bifrost bifrost/bifrost \ + --set image.tag=v1.4.11 \ + --set 'bifrost.providers.openai.keys[0].name=dev-key' \ + --set 'bifrost.providers.openai.keys[0].value=sk-your-key' \ + --set 'bifrost.providers.openai.keys[0].weight=1' +``` + +```bash +# Access +kubectl port-forward svc/bifrost 8080:8080 +``` + + + + +Multiple LLM providers with weighted load balancing. + +```bash +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-...' \ + --from-literal=anthropic-api-key='sk-ant-...' 
\ + --from-literal=gemini-api-key='your-gemini-key' +``` + +```yaml +# multi-provider.yaml +image: + tag: "v1.4.11" + +bifrost: + encryptionKey: "your-encryption-key" + + client: + enableLogging: true + allowDirectKeys: false + + providers: + openai: + keys: + - name: "openai-primary" + value: "env.OPENAI_API_KEY" + weight: 2 # 50% of traffic + anthropic: + keys: + - name: "anthropic-primary" + value: "env.ANTHROPIC_API_KEY" + weight: 1 # 25% + gemini: + keys: + - name: "gemini-primary" + value: "env.GEMINI_API_KEY" + weight: 1 # 25% + + providerSecrets: + openai: + existingSecret: "provider-keys" + key: "openai-api-key" + envVar: "OPENAI_API_KEY" + anthropic: + existingSecret: "provider-keys" + key: "anthropic-api-key" + envVar: "ANTHROPIC_API_KEY" + gemini: + existingSecret: "provider-keys" + key: "gemini-api-key" + envVar: "GEMINI_API_KEY" + + plugins: + telemetry: + enabled: true + logging: + enabled: true +``` + +```bash +helm install bifrost bifrost/bifrost -f multi-provider.yaml +``` + + + + +Use an existing PostgreSQL instance — RDS, Cloud SQL, Azure Database, or self-managed. + +```bash +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-external-postgres-password' +``` + +```yaml +# external-db.yaml +image: + tag: "v1.4.11" + +storage: + mode: postgres + +postgresql: + enabled: false + external: + enabled: true + host: "your-rds-endpoint.us-east-1.rds.amazonaws.com" + port: 5432 + user: "bifrost" + database: "bifrost" + sslMode: "require" + existingSecret: "postgres-credentials" + passwordKey: "password" + +bifrost: + encryptionKey: "your-encryption-key" + + providers: + openai: + keys: + - name: "openai-primary" + value: "sk-..." + weight: 1 +``` + +```bash +helm install bifrost bifrost/bifrost -f external-db.yaml +``` + + + + +Semantic response caching for high-volume AI inference. 
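Before wiring this up, it helps to see what the `threshold` in the semantic-cache config below actually gates. A minimal, illustrative sketch of threshold-based cache matching (plain Python, not Bifrost's implementation; it assumes cosine similarity over embedding vectors):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def cache_hit(query_embedding, cached_embedding, threshold=0.85):
    """Reuse a cached response only when similarity clears the threshold."""
    return cosine_similarity(query_embedding, cached_embedding) >= threshold

# Near-duplicate queries embed close together and clear 0.85 -> cache hit
print(cache_hit([1.0, 0.0, 0.2], [1.0, 0.1, 0.2]))   # True
# Unrelated queries fall below the threshold -> miss, call the provider
print(cache_hit([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]))   # False
```

With `threshold: 0.85` only queries whose embeddings sit very close to a cached entry reuse the cached response; lowering the threshold trades answer accuracy for a higher hit rate.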
+ +```bash +kubectl create secret generic bifrost-encryption \ + --from-literal=key='your-32-byte-encryption-key' + +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-your-key' +``` + +```yaml +# ai-workload.yaml +image: + tag: "v1.4.11" + +storage: + mode: postgres + +postgresql: + enabled: true + auth: + password: "secure-password" + primary: + persistence: + size: 50Gi + +vectorStore: + enabled: true + type: weaviate + weaviate: + enabled: true + persistence: + size: 50Gi + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "key" + + providers: + openai: + keys: + - name: "openai-primary" + value: "env.OPENAI_API_KEY" + weight: 1 + + providerSecrets: + openai: + existingSecret: "provider-keys" + key: "openai-api-key" + envVar: "OPENAI_API_KEY" + + plugins: + semanticCache: + enabled: true + config: + provider: "openai" + keys: + - value: "env.OPENAI_API_KEY" + weight: 1 + embedding_model: "text-embedding-3-small" + dimension: 1536 + threshold: 0.85 + ttl: "1h" + cache_by_model: true + cache_by_provider: true +``` + +```bash +helm install bifrost bifrost/bifrost -f ai-workload.yaml +``` + + + + +Zero credentials in values files — all sensitive data in Kubernetes Secrets. + +```bash +kubectl create secret generic postgres-credentials \ + --from-literal=password='your-postgres-password' + +kubectl create secret generic bifrost-encryption \ + --from-literal=key='your-encryption-key' + +kubectl create secret generic provider-keys \ + --from-literal=openai-api-key='sk-...' \ + --from-literal=anthropic-api-key='sk-ant-...' 
+ +kubectl create secret generic qdrant-credentials \ + --from-literal=api-key='your-qdrant-api-key' +``` + +```yaml +# secrets-only.yaml +image: + tag: "v1.4.11" + +storage: + mode: postgres + +postgresql: + enabled: false + external: + enabled: true + host: "postgres.example.com" + port: 5432 + user: "bifrost" + database: "bifrost" + sslMode: "require" + existingSecret: "postgres-credentials" + passwordKey: "password" + +vectorStore: + enabled: true + type: qdrant + qdrant: + enabled: false + external: + enabled: true + host: "qdrant.example.com" + port: 6334 + existingSecret: "qdrant-credentials" + apiKeyKey: "api-key" + +bifrost: + encryptionKeySecret: + name: "bifrost-encryption" + key: "key" + + providers: + openai: + keys: + - name: "openai-primary" + value: "env.OPENAI_API_KEY" + weight: 1 + anthropic: + keys: + - name: "anthropic-primary" + value: "env.ANTHROPIC_API_KEY" + weight: 1 + + providerSecrets: + openai: + existingSecret: "provider-keys" + key: "openai-api-key" + envVar: "OPENAI_API_KEY" + anthropic: + existingSecret: "provider-keys" + key: "anthropic-api-key" + envVar: "ANTHROPIC_API_KEY" +``` + +```bash +helm install bifrost bifrost/bifrost -f secrets-only.yaml +``` + + + diff --git a/docs/docs.json b/docs/docs.json index 155b91a096..b65a5c5936 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -60,7 +60,9 @@ { "group": "Overview", "icon": "book-open-cover", - "pages": ["overview"] + "pages": [ + "overview" + ] }, { "group": "Quick Start", @@ -100,12 +102,16 @@ { "group": "Migration Guides", "icon": "arrow-up-right-dots", - "pages": ["migration-guides/v1.5.0"] + "pages": [ + "migration-guides/v1.5.0" + ] }, { "group": "SDK Integrations", "icon": "plug", - "pages": ["integrations/what-is-an-integration"] + "pages": [ + "integrations/what-is-an-integration" + ] }, { "group": "Providers & Guides", @@ -178,7 +184,10 @@ { "group": "Writing Plugins", "icon": "code", - "pages": ["plugins/writing-go-plugin", "plugins/writing-wasm-plugin"] + "pages": [ 
+ "plugins/writing-go-plugin", + "plugins/writing-wasm-plugin" + ] }, "plugins/migration-guide" ] @@ -188,7 +197,7 @@ "icon": "bolt", "pages": [ "features/drop-in-replacement", - "features/fallbacks", + "features/retries-and-fallbacks", "features/litellm-compat", "features/keys-management", "features/async-inference", @@ -208,12 +217,18 @@ { "group": "Prompt Repository", "icon": "folder", - "pages": ["features/prompt-repository/playground", "features/prompt-repository/prompts-plugin"] + "pages": [ + "features/prompt-repository/playground", + "features/prompt-repository/prompts-plugin" + ] }, { "group": "Plugins", "icon": "puzzle-piece", - "pages": ["features/plugins/mocker", "features/plugins/jsonparser"] + "pages": [ + "features/plugins/mocker", + "features/plugins/jsonparser" + ] } ] }, @@ -228,7 +243,10 @@ { "group": "Advanced Governance", "icon": "shield-check", - "pages": ["enterprise/advanced-governance", "enterprise/rbac"] + "pages": [ + "enterprise/advanced-governance", + "enterprise/rbac" + ] }, "enterprise/mcp-with-fa", "enterprise/invpc-deployments", @@ -268,22 +286,33 @@ { "group": "OpenAI SDK", "icon": "openai", - "pages": ["integrations/openai-sdk/overview", "integrations/openai-sdk/files-and-batch"] + "pages": [ + "integrations/openai-sdk/overview", + "integrations/openai-sdk/files-and-batch" + ] }, { "group": "Anthropic SDK", "icon": "asterisk", - "pages": ["integrations/anthropic-sdk/overview", "integrations/anthropic-sdk/files-and-batch"] + "pages": [ + "integrations/anthropic-sdk/overview", + "integrations/anthropic-sdk/files-and-batch" + ] }, { "group": "Bedrock SDK", "icon": "aws", - "pages": ["integrations/bedrock-sdk/overview", "integrations/bedrock-sdk/files-and-batch"] + "pages": [ + "integrations/bedrock-sdk/overview", + "integrations/bedrock-sdk/files-and-batch" + ] }, { "group": "GenAI SDK", "icon": "diamond", - "pages": ["integrations/genai-sdk/overview"] + "pages": [ + "integrations/genai-sdk/overview" + ] }, 
"integrations/litellm-sdk", "integrations/langchain-sdk", @@ -360,11 +389,42 @@ { "group": "Platform specific guides", "icon": "swatchbook", + "pages": ["deployment-guides/k8s", "deployment-guides/ecs", "deployment-guides/fly"] + }, + { + "group": "Config as Code", + "icon": "code", "pages": [ - "deployment-guides/k8s", - "deployment-guides/ecs", - "deployment-guides/helm", - "deployment-guides/fly" + { + "group": "Helm", + "icon": "helicopter-symbol", + "pages": [ + "deployment-guides/helm", + "deployment-guides/helm/values", + "deployment-guides/helm/client", + "deployment-guides/helm/providers", + "deployment-guides/helm/storage", + "deployment-guides/helm/plugins", + "deployment-guides/helm/governance", + "deployment-guides/helm/guardrails", + "deployment-guides/helm/cluster", + "deployment-guides/helm/troubleshooting" + ] + }, + { + "group": "config.json", + "icon": "file-code", + "pages": [ + "deployment-guides/config-json", + "deployment-guides/config-json/schema-reference", + "deployment-guides/config-json/client", + "deployment-guides/config-json/providers", + "deployment-guides/config-json/storage", + "deployment-guides/config-json/plugins", + "deployment-guides/config-json/governance", + "deployment-guides/config-json/guardrails" + ] + } ] }, { @@ -430,7 +490,9 @@ { "tab": "Security", "icon": "shield", - "pages": ["security"] + "pages": [ + "security" + ] }, { "tab": "Benchmarks", @@ -453,6 +515,7 @@ "changelogs/v1.5.0-prerelease3", "changelogs/v1.5.0-prerelease2", "changelogs/v1.5.0-prerelease1", + "changelogs/v1.4.23", "changelogs/v1.4.22", "changelogs/v1.4.21", { @@ -590,7 +653,10 @@ }, { "group": "September 2025", - "pages": ["changelogs/v1.2.22", "changelogs/v1.2.21"] + "pages": [ + "changelogs/v1.2.22", + "changelogs/v1.2.21" + ] } ] }, @@ -627,6 +693,10 @@ ] }, "redirects": [ + { + "source": "/features/fallbacks", + "destination": "/features/retries-and-fallbacks" + }, { "source": "/quickstart/gateway/cli-agents", "destination": 
"/cli-agents/overview" diff --git a/docs/enterprise/vault-support.mdx b/docs/enterprise/vault-support.mdx deleted file mode 100644 index cd63a630ab..0000000000 --- a/docs/enterprise/vault-support.mdx +++ /dev/null @@ -1,133 +0,0 @@ ---- -title: "Vault Support" -description: "Secure API key management with HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault integration. Store and retrieve sensitive credentials using enterprise-grade secret management." -icon: "vault" ---- - -Bifrost's vault support enables seamless integration with enterprise-grade secret management systems, allowing you to connect to existing vaults and automatically sync virtual keys and provider API keys directly onto the Bifrost platform. - -## Overview - -The vault integration provides: - -- **Automated Key Synchronization**: Connect to your existing vault infrastructure and sync all API keys automatically -- **Periodic Key Management**: Regular synchronization ensures deprecated and archived keys are properly managed -- **Multi-Vault Support**: Compatible with HashiCorp Vault, AWS Secrets Manager, Google Secret Manager, and Azure Key Vault -- **Zero-Downtime Operations**: Keys are synced without interrupting your running services - -## Supported Vault Systems - - - - Centralized secret management for self-hosted deployments. - - - Cloud-native secret storage on AWS. - - - Secure key storage on Google Cloud Platform. - - - Key management for Microsoft Azure environments. - - - -## Key Synchronization - -### Automatic Sync Process - -Bifrost automatically synchronizes keys from your vault at regular intervals: - -1. **Discovery**: Scans the configured vault paths for API keys and virtual keys -2. **Validation**: Verifies key format and accessibility -3. **Sync**: Updates Bifrost's internal key store with new and modified keys -4. **Deprecation**: Identifies and archives keys that have been removed from the vault -5. 
**Notification**: Logs sync status and any issues encountered - -### Sync Configuration - -Configure synchronization behavior to match your operational requirements: - -```json -{ - "vault": { - "sync_interval": "300s", - "sync_paths": [ - "bifrost/provider-keys/*", - "bifrost/virtual-keys/*" - ], - "auto_deprecate": true, - "backup_deprecated_keys": true - } -} -``` - -#### Configuration Options - -| Option | Description | Default | -|--------|-------------|---------| -| `sync_interval` | Time between sync operations | `300s` | -| `sync_paths` | Vault paths to monitor for keys | `["bifrost/*"]` | -| `auto_deprecate` | Automatically deprecate removed keys | `true` | -| `backup_deprecated_keys` | Backup keys before deprecation | `true` | - -## Key Management Lifecycle - -### Key States - -Keys in Bifrost can have the following states: - -- **Active**: Currently in use and available for requests -- **Deprecated**: Marked for removal but still functional -- **Archived**: Removed from active use but retained for audit purposes -- **Expired**: Keys that have exceeded their validity period - -### Deprecation Process - -When keys are removed from the vault: - -1. **Detection**: Next sync cycle identifies missing keys -2. **Grace Period**: Keys enter deprecated state with configurable grace period -3. **Notification**: Administrators are notified of pending deprecation -4. 
**Archive**: Keys are moved to archived state after grace period expires - -```json -{ - "vault": { - "deprecation": { - "grace_period": "24h", - "notify_admins": true, - "retain_archived": "90d" - } - } -} -``` - -## Security Considerations - -### Authentication - -- **Vault Tokens**: Use time-limited tokens with minimal required permissions -- **IAM Roles**: Leverage cloud provider IAM roles for secure authentication -- **Certificate-based Auth**: Support for mutual TLS authentication where available - -### Encryption - -- **Transit Encryption**: All communication with vault systems uses TLS -- **At-Rest Encryption**: Keys are encrypted in Bifrost's internal storage -- **Key Rotation**: Automatic detection and handling of rotated vault credentials - -### Audit Trail - -Complete audit logging for all vault operations: - -```json -{ - "timestamp": "2024-01-15T10:30:00Z", - "operation": "key_sync", - "vault_type": "hashicorp", - "keys_synced": 15, - "keys_deprecated": 2, - "status": "success" -} -``` diff --git a/docs/features/fallbacks.mdx b/docs/features/fallbacks.mdx deleted file mode 100644 index 23a53c976e..0000000000 --- a/docs/features/fallbacks.mdx +++ /dev/null @@ -1,187 +0,0 @@ ---- -title: "Fallbacks" -description: "Automatic failover between AI providers and models. When your primary provider fails, Bifrost seamlessly switches to backup providers without interrupting your application." -icon: "list-check" ---- - -## Automatic Provider Failover - -Fallbacks provide automatic failover when your primary AI provider experiences issues. Whether it's rate limiting, outages, or model unavailability, Bifrost automatically tries backup providers in the order you specify until one succeeds. - -When a fallback is triggered, Bifrost treats it as a completely new request - all configured plugins (caching, governance, logging, etc.) run again for the fallback provider, ensuring consistent behavior across all providers. 
- -## How Fallbacks Work - -When you configure fallbacks, Bifrost follows this process: - -1. **Primary Attempt**: Tries your main provider/model first -2. **Automatic Detection**: If the primary fails (network error, rate limit, model unavailable), Bifrost detects the failure -3. **Sequential Fallbacks**: Tries each fallback provider in order until one succeeds -4. **Success Response**: Returns the response from the first successful provider -5. **Complete Failure**: If all providers fail, returns the original error from the primary provider - -Each fallback attempt is treated as a fresh request, so all your configured plugins (semantic caching, governance rules, monitoring) apply to whichever provider ultimately handles the request. - -## Implementation Examples - - - - -```bash -# Chat completion with multiple fallbacks -curl -X POST http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "openai/gpt-4o-mini", - "messages": [ - { - "role": "user", - "content": "Explain quantum computing in simple terms" - } - ], - "fallbacks": [ - "anthropic/claude-3-5-sonnet-20241022", - "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" - ], - "max_tokens": 1000, - "temperature": 0.7 - }' -``` - -**Response (from whichever provider succeeded):** -```json -{ - "id": "chatcmpl-123", - "object": "chat.completion", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Quantum computing is like having a super-powered calculator..." 
- }, - "finish_reason": "stop" - } - ], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 150, - "total_tokens": 162 - }, - "extra_fields": { - "provider": "anthropic", - "latency": 1.2 - } -} -``` - - - - - -```go -package main - -import ( - "context" - "fmt" - "github.com/maximhq/bifrost" - "github.com/maximhq/bifrost/core/schemas" -) - -func chatWithFallbacks(client *bifrost.Bifrost) { - ctx := context.Background() - - // Chat request with multiple fallbacks - response, err := client.ChatCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), &schemas.BifrostChatRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: []schemas.ChatMessage{ - { - Role: schemas.ChatMessageRoleUser, - Content: schemas.ChatMessageContent{ - ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), - }, - }, - }, - // Fallback chain: OpenAI → Anthropic → Bedrock - Fallbacks: []schemas.Fallback{ - { - Provider: schemas.Anthropic, - Model: "claude-3-5-sonnet-20241022", - }, - { - Provider: schemas.Bedrock, - Model: "anthropic.claude-3-sonnet-20240229-v1:0", - }, - }, - Params: &schemas.ChatParameters{ - MaxCompletionTokens: bifrost.Ptr(1000), - Temperature: bifrost.Ptr(0.7), - }, - }) - - if err != nil { - fmt.Printf("All providers failed: %v\n", err) - return - } - - // Success! 
Response came from whichever provider worked - fmt.Printf("Response from %s: %s\n", - response.ExtraFields.Provider, - *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) -} -``` - - - - - -## Real-World Scenarios - -**Scenario 1: Rate Limiting** -- Primary: OpenAI hits rate limit → Fallback: Anthropic succeeds -- Your application continues without interruption - -**Scenario 2: Model Unavailability** -- Primary: Specific model unavailable → Fallback: Different provider with similar model -- Seamless transition to equivalent capability - -**Scenario 3: Provider Outage** -- Primary: Provider experiencing downtime → Fallback: Alternative provider -- Business continuity maintained - -**Scenario 4: Cost Optimization** -- Primary: Premium model for quality → Fallback: Cost-effective alternative if budget exceeded -- Governance rules can trigger fallbacks based on usage - -## Fallback Behavior Details - -**What Triggers Fallbacks:** -- Network connectivity issues -- Provider API errors (500, 502, 503, 504) -- Rate limiting (429 errors) -- Model unavailability -- Request timeouts -- Authentication failures - -**What Preserves Original Error:** -- Request validation errors (malformed requests) -- Plugin-enforced blocks (governance violations) -- Certain provider-specific errors marked as non-retryable - -**Plugin Execution:** -When a fallback is triggered, the fallback request is treated as completely new: -- Semantic cache checks run again (different provider might have cached responses) -- Governance rules apply to the new provider -- Logging captures the fallback attempt -- All configured plugins execute fresh for the fallback provider - -**Plugin Fallback Control:** -Plugins can control whether fallbacks should be triggered based on their specific logic. 
For example: -- A custom plugin might prevent fallbacks for certain types of errors -- Security plugins might disable fallbacks for compliance reasons - -When a plugin determines that fallbacks should not be attempted, it can prevent the fallback mechanism entirely, ensuring the original error is returned immediately. - -This ensures consistent behavior regardless of which provider ultimately handles your request, while giving plugins full control over the fallback decision process. And you can always know which provider handled your request via `extra_fields`. diff --git a/docs/features/observability/default.mdx b/docs/features/observability/default.mdx index 4572aec2eb..ad60c254c4 100644 --- a/docs/features/observability/default.mdx +++ b/docs/features/observability/default.mdx @@ -24,12 +24,45 @@ Bifrost traces comprehensive information for every request, without any changes - **Input Messages**: Complete conversation history and user prompts - **Model Parameters**: Temperature, max tokens, tools, and all other parameters - **Provider Context**: Which provider and model handled the request +- **Prompt Tracking**: When the [Prompts plugin](/features/prompt-repository/prompts-plugin) is active, the log captures the selected prompt name, version number, and ID for full traceability ### **Response Data** - **Output Messages**: AI responses, tool calls, and function results - **Performance Metrics**: Latency and token usage - **Status Information**: Success or error details +### **Retry & Key Selection** v1.5.0-prerelease4+ + +When Bifrost retries a request (rate-limit or network error) the following fields are recorded: + +| Field | Meaning | +|-------|---------| +| `selected_key_id` / `selected_key_name` | The API key that **successfully** served the request. `null` when all attempts failed — use `attempt_trail` to see which keys were tried. | +| `number_of_retries` | Total number of attempts minus one. 
**Does not indicate which key was used on each attempt.** | +| `attempt_trail` | Ordered array of every attempt, with key used and failure reason. `fail_reason` is `null` on the final attempt. | + +**Example `attempt_trail`** — two rate-limit rotations then success on a third key: + +```json +"attempt_trail": [ + { "attempt": 0, "key_id": "key-a", "key_name": "Key A", "fail_reason": "rate_limit_error" }, + { "attempt": 1, "key_id": "key-b", "key_name": "Key B", "fail_reason": "rate_limit_error" }, + { "attempt": 2, "key_id": "key-c", "key_name": "Key C", "fail_reason": null } +] +``` + +Network-error retries reuse the same key; only rate-limit errors rotate to a different key: + +```json +"attempt_trail": [ + { "attempt": 0, "key_id": "key-a", "key_name": "Key A", "fail_reason": "network_error" }, + { "attempt": 1, "key_id": "key-a", "key_name": "Key A", "fail_reason": "rate_limit_error" }, + { "attempt": 2, "key_id": "key-b", "key_name": "Key B", "fail_reason": null } +] +``` + +`attempt_trail` is `null` / absent when the request succeeded on the first try without retries. 
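As a quick illustration of how these fields relate, here is a small Python sketch (not Bifrost code; the function name is ours, and it assumes the `attempt_trail` shape shown above) that derives `number_of_retries` and the selected key from a trail:

```python
def summarize_attempts(attempt_trail):
    """Derive the retry-related log fields from an attempt_trail array."""
    if not attempt_trail:
        # attempt_trail is null/absent when the first attempt succeeded
        return {"number_of_retries": 0, "selected_key_name": None}
    last = attempt_trail[-1]
    # fail_reason is null only on a successful final attempt
    succeeded = last["fail_reason"] is None
    return {
        "number_of_retries": len(attempt_trail) - 1,
        "selected_key_name": last["key_name"] if succeeded else None,
    }

trail = [
    {"attempt": 0, "key_id": "key-a", "key_name": "Key A", "fail_reason": "rate_limit_error"},
    {"attempt": 1, "key_id": "key-b", "key_name": "Key B", "fail_reason": "rate_limit_error"},
    {"attempt": 2, "key_id": "key-c", "key_name": "Key C", "fail_reason": None},
]
print(summarize_attempts(trail))
# {'number_of_retries': 2, 'selected_key_name': 'Key C'}
```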
+ ### **Custom Metadata** - **Logging Headers**: Capture configured request headers (e.g., `X-Tenant-ID`) into log metadata - **Ad-hoc Headers**: Any `x-bf-lh-*` prefixed header is automatically captured into metadata diff --git a/docs/features/observability/prometheus.mdx b/docs/features/observability/prometheus.mdx index 6c6df6a24f..3512783e8d 100644 --- a/docs/features/observability/prometheus.mdx +++ b/docs/features/observability/prometheus.mdx @@ -178,6 +178,7 @@ The following metrics are available from both the `/metrics` endpoint and Push G | `bifrost_cache_hits_total` | Counter | Cache hits by type | | `bifrost_stream_first_token_latency_seconds` | Histogram | Time to first token (streaming) | | `bifrost_stream_inter_token_latency_seconds` | Histogram | Inter-token latency (streaming) | +| `bifrost_key_rotation_events_total` | Counter | Per-attempt retry/rotation events with key identifiers (see below) v1.5.0-prerelease4+ | ### Default Labels @@ -187,12 +188,42 @@ All Bifrost metrics include these labels: - `model` - Model identifier - `method` - Request type (chat, completion, embedding, etc.) - `virtual_key_id` / `virtual_key_name` - Virtual key identifiers -- `selected_key_id` / `selected_key_name` - Actual key used -- `number_of_retries` - Retry count +- `selected_key_id` / `selected_key_name` - API key that successfully served the request (`""` when all attempts failed) +- `number_of_retries` - Total attempts minus one (across all keys) - `fallback_index` - Fallback position - `team_id` / `team_name` - Team identifiers (if governance enabled) - `customer_id` / `customer_name` - Customer identifiers (if governance enabled) + + **v1.5.0-prerelease4+**: `selected_key_id` / `selected_key_name` are only populated when the request succeeds. On final errors both are empty — use `bifrost_key_rotation_events_total` or the `attempt_trail` log field to see which keys were tried. 
+ + +### Key Rotation Events v1.5.0-prerelease4+ + +`bifrost_key_rotation_events_total` is incremented once per **failed attempt** (not per request), giving you time-series visibility into retry pressure: + +| Label | Values | Description | +|-------|--------|-------------| +| `provider` | e.g. `openai` | LLM provider | +| `requested_model` | e.g. `gpt-4o` | Model as requested (before any alias resolution) | +| `key_id` | UUID | The provider API key that failed on this attempt | +| `key_name` | string | Human-readable name of the provider API key | +| `fail_reason` | error type string | Provider error type (e.g. `rate_limit_error`, `network_error`) | + +**Example queries:** + +```promql +# Rate-limit events per provider over time +sum by (provider, fail_reason) ( + rate(bifrost_key_rotation_events_total[5m]) +) + +# Which specific keys are hitting rate limits most often +topk(5, sum by (provider, key_name, fail_reason) ( + rate(bifrost_key_rotation_events_total{fail_reason="rate_limit_error"}[1h]) +)) +``` + --- ## Push Gateway Setup diff --git a/docs/features/prompt-repository/playground.mdx b/docs/features/prompt-repository/playground.mdx index 30c2b0df9a..879006c7d2 100644 --- a/docs/features/prompt-repository/playground.mdx +++ b/docs/features/prompt-repository/playground.mdx @@ -85,7 +85,12 @@ The playground uses a simple **three-panel layout**: |------|---------| | **Sidebar (left)** | Browse prompts, manage folders, and organize items | | **Playground (center)** | Build and test your prompt messages | -| **Settings (right)** | Configure provider, model, API key, variables, and parameters | +| **Settings (right)** | Configure provider, model, API key, variables, parameters, and deployments | + +The settings panel is organized into collapsible sections: + +- **Configuration** — Provider, model, API key, variables, and model parameters +- **Deployments** — Prompt deployment strategies and traffic routing (enterprise) ![Workspace 
Layout](../../media/prompt-repo-layout.png) diff --git a/docs/features/prompt-repository/prompts-plugin.mdx b/docs/features/prompt-repository/prompts-plugin.mdx index 50fd6def32..a06d817245 100644 --- a/docs/features/prompt-repository/prompts-plugin.mdx +++ b/docs/features/prompt-repository/prompts-plugin.mdx @@ -29,13 +29,13 @@ The **Prompts** plugin connects the [Prompt Repository](/features/prompt-reposit ```mermaid flowchart TB Client([Client]) --> Gateway[Bifrost HTTP] - Gateway --> PreHook["HTTP transport pre-hook:
copy bf-prompt-id / bf-prompt-version to context"] + Gateway --> PreHook["HTTP transport pre-hook:
copy x-bf-prompt-id / x-bf-prompt-version to context"] PreHook --> PreLLM["PreLLM hook:
resolve version, merge params,
prepend template messages"] PreLLM --> Provider[Provider] ``` -1. **Transport (HTTP):** Incoming headers `bf-prompt-id` and `bf-prompt-version` are copied onto the Bifrost context (header name matching is case-insensitive). -2. **Resolve:** The plugin looks up the prompt and the requested version. If **`bf-prompt-version` is omitted**, the prompt’s **latest committed version** is used. +1. **Transport (HTTP):** Incoming headers `x-bf-prompt-id` and `x-bf-prompt-version` are copied onto the Bifrost context (header name matching is case-insensitive). +2. **Resolve:** The plugin looks up the prompt and the requested version. If **`x-bf-prompt-version` is omitted**, the prompt’s **latest committed version** is used. 3. **Parameters:** Version `model` parameters are merged into the request; any field already set on the request wins. 4. **Messages:** Messages from the committed version are **prepended** to `messages` (chat) or `input` (responses). Your request body adds the user turn(s) after the template. @@ -47,8 +47,8 @@ If the prompt ID is missing, the plugin does nothing and the request passes thro | Header | Required | Description | |--------|----------|-------------| -| `bf-prompt-id` | Yes, to enable injection | UUID of the prompt in the repository. | -| `bf-prompt-version` | No | **Integer version number** (e.g. `3` for v3). If omitted, the **latest** committed version for that prompt is used. | +| `x-bf-prompt-id` | Yes, to enable injection | UUID of the prompt in the repository. | +| `x-bf-prompt-version` | No | **Integer version number** (e.g. `3` for v3). If omitted, the **latest** committed version for that prompt is used. | Invalid or unknown IDs / versions are logged as warnings; the request is **not** failed by the plugin (it proceeds without template injection). @@ -61,7 +61,7 @@ Use the same JSON body as a normal chat request. 
Only the headers select the tem ```bash curl -X POST http://localhost:8080/v1/chat/completions \ -H "Content-Type: application/json" \ - -H "bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "x-bf-prompt-id: YOUR-PROMPT-UUID" \ -H "x-bf-vk: sk-bf-your-virtual-key" \ -d '{ "model": "openai/gpt-5.4", @@ -76,13 +76,13 @@ curl -X POST http://localhost:8080/v1/chat/completions \ ![Commit Version with Stream enabled in the playground](../../media/prompt-plugin-version-commit.png) -When you commit a version from the playground, **Stream** is saved in that version’s model parameters. The example `curl` above does not set `"stream": true` in the JSON body, but if the committed version was saved with streaming enabled (as in the screenshot), the merged parameters still include `stream: true`, so the request is handled as **streaming** even though the client did not send `stream` explicitly. +When you commit a version from the playground, the model parameters (temperature, max tokens, etc.) are saved with it. These parameters are merged into the outgoing request, with client-supplied values taking precedence. ![LLM log for the same request showing Type: Chat Stream](../../media/prompt-plugin-llm-log.png) -In **Logs**, that run shows **Type: Chat Stream** and the full conversation: the committed **system** template, your **user** message from the request body, and the assistant reply. +In **Logs**, that run shows the full conversation: the committed **system** template, your **user** message from the request body, and the assistant reply. The log also displays the **Selected Prompt** name and version number for easy traceability. -The provider receives the **stored** messages from the prompt version, checks if the request is streaming or non-streaming, applies the additional model parameters from the request and prepends the messages from the prompt version followed by your user message. 
+The provider receives the merged model parameters from both the prompt version and the client request, with the messages from the committed version prepended before the client’s messages. --- @@ -91,8 +91,8 @@ The provider receives the **stored** messages from the prompt version, checks if ```bash curl -X POST http://localhost:8080/v1/responses \ -H "Content-Type: application/json" \ - -H "bf-prompt-id: YOUR-PROMPT-UUID" \ - -H "bf-prompt-version: 4" \ + -H "x-bf-prompt-id: YOUR-PROMPT-UUID" \ + -H "x-bf-prompt-version: 4" \ -H "x-bf-vk: sk-bf-your-virtual-key" \ -d '{ "model": "openai/gpt-5-nano-2025-08-07", @@ -104,7 +104,7 @@ curl -X POST http://localhost:8080/v1/responses \ ## Streaming -If the committed version’s **model parameters** include `"stream": true`, the plugin may set streaming on the HTTP transport so behavior matches the saved version. Client-side `stream` flags still interact with the merged parameters as usual. +Streaming is controlled entirely by the client request. If you want streaming, set `"stream": true` in the request body. The plugin merges model parameters from the committed version (request values take precedence), but does **not** override the transport-level streaming mode. --- @@ -118,12 +118,26 @@ The plugin keeps an in-memory cache of prompts and versions (loaded with a small For embedded Bifrost (Go SDK), register the plugin with `prompts.Init` and a **config store** that implements the prompt tables API. The default resolver reads the same logical keys from `BifrostContext`: -- `prompts.PromptIDKey` (`bf-prompt-id`) -- `prompts.PromptVersionKey` (`bf-prompt-version`) +- `prompts.PromptIDKey` (`x-bf-prompt-id`) +- `prompts.PromptVersionKey` (`x-bf-prompt-version`) Set them on the context you pass to `ChatCompletion` / `Responses` if you are not going through the HTTP transport hooks. 
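As a pseudocode-level sketch only — this doc names the keys but not the setter on `BifrostContext`, so `SetValue` below is a hypothetical stand-in for whatever the SDK actually exposes; verify against the `schemas` package before use:

```go
// Hypothetical sketch — SetValue is a stand-in, not a confirmed API.
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
ctx.SetValue(prompts.PromptIDKey, "YOUR-PROMPT-UUID") // enables template injection
ctx.SetValue(prompts.PromptVersionKey, "3")           // optional; omit to use latest
resp, err := client.ChatCompletionRequest(ctx, req)
```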
-For advanced routing (for example, choosing a prompt from governance metadata), implement `prompts.PromptResolver` in `plugins/prompts/main.go` and use **`prompts.InitWithResolver`**. +For advanced routing (for example, choosing a prompt from governance metadata), implement `prompts.PromptResolver` and use **`prompts.InitWithResolver`**. The interface is: + +```go +type PromptResolver interface { + Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, err error) +} +``` + +Return an empty `promptID` to skip injection for a request. Return `versionNumber == 0` to use the prompt's **latest** committed version; any positive integer selects that specific version. + +After injection, the plugin sets the following context keys (read by the logging plugin to populate log fields): + +- `schemas.BifrostContextKeySelectedPromptID` — UUID of the applied prompt +- `schemas.BifrostContextKeySelectedPromptName` — Display name of the prompt +- `schemas.BifrostContextKeySelectedPromptVersion` — Version number as a string (e.g. `"3"`) --- diff --git a/docs/features/retries-and-fallbacks.mdx b/docs/features/retries-and-fallbacks.mdx new file mode 100644 index 0000000000..4d42afafde --- /dev/null +++ b/docs/features/retries-and-fallbacks.mdx @@ -0,0 +1,389 @@ +--- +title: "Retries & Fallbacks" +description: "Automatic retry with exponential backoff and provider failover. Retries handle transient errors within a provider; fallbacks switch to a different provider when all retries are exhausted." +icon: "list-check" +--- + +## Overview + +Bifrost provides two complementary layers of resilience: + +- **Retries** — When a provider returns a transient error (network issue, rate limit, 5xx), Bifrost automatically retries the same request against the same provider with exponential backoff. On rate-limit errors, it can also rotate to a different API key from your pool. 
+- **Fallbacks** — When the primary provider fails after exhausting all retries, Bifrost moves on to the next provider in your fallback chain. Each fallback provider gets its own full retry budget. + +Together, they let you build LLM-powered applications that stay up through rate limits, transient outages, and even full provider failures — with no changes required in your application code. + +--- + +## Retries + +### How retries work + +When a request fails with a retryable error, Bifrost: + +1. Waits using **exponential backoff with jitter** before the next attempt +2. Retries the request against the same provider +3. On **rate-limit errors** (`429`): rotates to a different API key from the pool (if multiple keys are configured) so fresh capacity is used +4. On **network/server errors** (`5xx`, DNS, connection refused): reuses the same key — these are transient server issues, not per-key capacity problems +5. Continues until the request succeeds or `max_retries` is exhausted + +### Backoff formula + +``` +backoff = min(retry_backoff_initial × 2^attempt, retry_backoff_max) × jitter(0.8–1.2) +``` + +With the defaults of `retry_backoff_initial = 500ms` and `retry_backoff_max = 5000ms`: + +| Attempt | Base backoff | With jitter (approx.) | +|---------|-------------|----------------------| +| 1st retry | 500 ms | 400–600 ms | +| 2nd retry | 1000 ms | 800 ms–1.2 s | +| 3rd retry | 2000 ms | 1.6–2.4 s | +| 4th retry | 4000 ms | 3.2–4.8 s | +| 5th+ retry | 5000 ms (capped) | 4–6 s | + +### What triggers a retry + +| Condition | Retried? | Key rotation?
| +|-----------|----------|---------------| +| Network error (DNS, connection refused) | Yes | No — same key reused | +| `5xx` server errors (500, 502, 503, 504) | Yes | No — same key reused | +| Rate limit (`429` or rate-limit message pattern) | Yes | Yes — next key from pool | +| Request validation error | No | — | +| Plugin-enforced block | No | — | +| Cancelled request | No | — | + +### Configuring retries + +Retries are configured per-provider in `network_config`. The defaults are `max_retries: 0` (no retries), `retry_backoff_initial: 500` ms, and `retry_backoff_max: 5000` ms. + + + + +Navigate to **Providers**, select a provider, and open the **Network Config** section. + +Set: +- **Max Retries** — number of additional attempts after the first failure (e.g. `3`) +- **Retry Backoff Initial** — starting backoff in milliseconds (e.g. `500`) +- **Retry Backoff Max** — maximum backoff cap in milliseconds (e.g. `5000`) + + + + +```bash +curl --location 'http://localhost:8080/api/providers' \ +--header 'Content-Type: application/json' \ +--data '{ + "provider": "openai", + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } +}' +``` + + + + +```go +func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { + switch provider { + case schemas.OpenAI: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + MaxRetries: 3, + RetryBackoffInitial: 500 * time.Millisecond, + RetryBackoffMax: 5 * time.Second, + }, + ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + +```json +{ + "providers": { + "openai": { + "keys": [ + { "name": "openai-key-1", "value": "env.OPENAI_KEY_1", "models": ["*"], "weight": 1.0 }, + { "name": 
"openai-key-2", "value": "env.OPENAI_KEY_2", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-3", "value": "env.OPENAI_KEY_3", "models": ["*"], "weight": 1.0 } + ], + "network_config": { + "max_retries": 3, + "retry_backoff_initial": 500, + "retry_backoff_max": 5000 + } + } + } +} +``` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `max_retries` | integer | `0` | Number of additional attempts after the first failure | +| `retry_backoff_initial` | integer (ms) | `500` | Starting backoff duration in milliseconds | +| `retry_backoff_max` | integer (ms) | `5000` | Maximum backoff cap in milliseconds | + + + + +### Key rotation on rate limits + + +Key rotation on retries requires **v1.5.0-prerelease4 or later**. + + +When you configure multiple API keys for a provider, Bifrost automatically rotates to a fresh key when a rate-limit error is encountered — so retries are not wasted repeating a request with a key that has already hit its limit. + +```json +{ + "providers": { + "openai": { + "keys": [ + { "name": "openai-key-1", "value": "env.OPENAI_KEY_1", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-2", "value": "env.OPENAI_KEY_2", "models": ["*"], "weight": 1.0 }, + { "name": "openai-key-3", "value": "env.OPENAI_KEY_3", "models": ["*"], "weight": 1.0 } + ], + "network_config": { + "max_retries": 5 + } + } + } +} +``` + +With 3 keys and `max_retries: 5`, Bifrost cycles through all three keys twice before giving up. Once all keys in the pool have been tried, it resets and starts a fresh weighted round. + + +Key rotation on rate limits only applies when `max_retries > 0` and more than one key is configured for the provider. With a single key, all retries reuse that key. + + +--- + +## Fallbacks + +Fallbacks provide automatic failover to a different provider when the primary fails after exhausting all its retries. Each fallback is tried in order until one succeeds. + +### How fallbacks work + +1. 
**Primary attempt**: Tries your configured provider with its full retry budget +2. **Fallback decision**: If the primary fails (and the error is retryable at the provider level), Bifrost moves to the first fallback +3. **Sequential fallbacks**: Each fallback provider also gets its own full retry budget +4. **First success wins**: Returns the response from the first provider that succeeds +5. **All fail**: Returns the original error from the primary provider. Exception: if a plugin on a fallback provider sets `AllowFallbacks = false` on the error (e.g. a security or compliance plugin that should halt the chain regardless of remaining fallbacks), Bifrost stops immediately and returns that fallback's error rather than continuing to the next provider or returning the primary error. + +Each fallback is treated as a completely fresh request — all configured plugins (semantic caching, governance, logging) run again for the fallback provider. + +### Implementation + + + + +Pass a `fallbacks` array in the request body. Each entry specifies a `provider/model` string: + +```bash +curl -X POST http://localhost:8080/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai/gpt-4o-mini", + "messages": [ + { + "role": "user", + "content": "Explain quantum computing in simple terms" + } + ], + "fallbacks": [ + "anthropic/claude-3-5-sonnet-20241022", + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + ], + "max_tokens": 1000, + "temperature": 0.7 + }' +``` + +The response `extra_fields.provider` tells you which provider actually served the request: + +```json +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Quantum computing is like having a super-powered calculator..." 
+ }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 12, + "completion_tokens": 150, + "total_tokens": 162 + }, + "extra_fields": { + "provider": "anthropic", + "latency": 1.2 + } +} +``` + + + + +```go +package main + +import ( + "context" + "fmt" + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/core/schemas" +) + +func chatWithFallbacks(client *bifrost.Bifrost) { + ctx := context.Background() + + response, err := client.ChatCompletionRequest( + schemas.NewBifrostContext(ctx, schemas.NoDeadline), + &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), + }, + }, + }, + // Fallback chain: OpenAI → Anthropic → Bedrock + Fallbacks: []schemas.Fallback{ + {Provider: schemas.Anthropic, Model: "claude-3-5-sonnet-20241022"}, + {Provider: schemas.Bedrock, Model: "anthropic.claude-3-sonnet-20240229-v1:0"}, + }, + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1000), + Temperature: bifrost.Ptr(0.7), + }, + }, + ) + + if err != nil { + fmt.Printf("All providers failed: %v\n", err) + return + } + + fmt.Printf("Response from %s: %s\n", + response.ExtraFields.Provider, + *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) +} +``` + + + + +--- + +## How retries and fallbacks work together + +The two mechanisms form a nested resilience loop. Retries run inside each provider attempt; fallbacks run across providers once retries are exhausted. 
+ +```mermaid +sequenceDiagram + participant App + participant Bifrost + participant Primary as Primary Provider + participant FB1 as Fallback 1 + participant FB2 as Fallback 2 + + App->>Bifrost: Request (primary + fallbacks) + + rect rgb(220, 235, 250) + note over Bifrost,Primary: Primary provider attempt (with retries) + Bifrost->>Primary: Attempt 1 + Primary-->>Bifrost: 429 Rate Limit + note over Bifrost: Backoff + rotate key + Bifrost->>Primary: Attempt 2 (different key) + Primary-->>Bifrost: 503 Unavailable + note over Bifrost: Backoff + Bifrost->>Primary: Attempt 3 + Primary-->>Bifrost: 503 Unavailable + note over Bifrost: max_retries exhausted + end + + rect rgb(235, 250, 220) + note over Bifrost,FB1: Fallback 1 attempt (with its own retries) + Bifrost->>FB1: Attempt 1 + FB1-->>Bifrost: 500 Server Error + note over Bifrost: Backoff + Bifrost->>FB1: Attempt 2 + FB1-->>Bifrost: ✓ Success + end + + Bifrost-->>App: Response (from Fallback 1) +``` + +**Key point:** each provider in the chain — primary and every fallback — gets its own full `max_retries` budget. A primary configured with `max_retries: 3` and two fallbacks each also configured with `max_retries: 3` means up to 12 total attempts before giving up. + + +The retry budget is set per-provider in `network_config`. If your fallback providers have different retry configurations, each will use their own settings. + + +--- + +## Real-world scenarios + +**Scenario 1: Rate limiting with key rotation** + +OpenAI key 1 hits its rate limit. Bifrost rotates to key 2 on the next retry — no fallback needed, the request succeeds within the same provider. + +**Scenario 2: Provider outage** + +OpenAI is experiencing downtime (returning `503`). Bifrost retries with the same key (transient server issue), exhausts `max_retries`, then fails over to Anthropic. Anthropic succeeds on the first attempt. + +**Scenario 3: Cascading failure** + +Both primary and first fallback are down. 
Bifrost works through each provider's retry budget sequentially until the second fallback succeeds. + +**Scenario 4: Cost-sensitive fallback** + +Primary: a premium model for quality. Fallback: a cost-effective alternative. Governance rules can trigger a budget-exceeded error on the primary, which cascades into the fallback chain. + +--- + +## Plugin execution + +When a fallback is triggered, the fallback request is treated as completely new: + +- Semantic cache checks run again (the fallback provider may have a cached response) +- Governance rules apply to the new provider +- Logging captures the fallback attempt separately +- All configured plugins execute fresh for each provider in the chain + +**Plugin fallback control:** Plugins can prevent fallbacks from being triggered for specific error types. For example, a security plugin might disable fallbacks for compliance reasons. When a plugin sets `AllowFallbacks = false` on the error, the fallback chain is skipped entirely and the original error is returned immediately. 
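The halt-the-chain behavior can be sketched with an illustrative error type — `chainError` below stands in for Bifrost's real error schema, whose actual shape is richer:

```go
package main

import "fmt"

// chainError is an illustrative stand-in for Bifrost's error type,
// carrying only the AllowFallbacks flag discussed above.
type chainError struct {
	msg            string
	allowFallbacks bool
}

// runChain walks the fallback chain, stopping early when a plugin
// has set allowFallbacks = false on the returned error.
func runChain(providers []func() (string, *chainError)) (string, *chainError) {
	var primaryErr *chainError
	for _, call := range providers {
		resp, err := call()
		if err == nil {
			return resp, nil
		}
		if primaryErr == nil {
			primaryErr = err
		}
		if !err.allowFallbacks {
			return "", err // plugin halted the chain: surface this error immediately
		}
	}
	return "", primaryErr // all failed: the primary's error wins
}

func main() {
	primary := func() (string, *chainError) {
		return "", &chainError{msg: "rate limited", allowFallbacks: true}
	}
	blocked := func() (string, *chainError) {
		return "", &chainError{msg: "compliance block", allowFallbacks: false}
	}
	never := func() (string, *chainError) { return "unreachable", nil }

	_, err := runChain([]func() (string, *chainError){primary, blocked, never})
	fmt.Println(err.msg) // compliance block
}
```

Note the third provider is never called: once the compliance plugin clears the flag, the chain stops and that fallback's error is returned instead of the primary's.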
+ +--- + +## Next steps + +- **[Keys Management](./keys-management)** — Configure multiple API keys per provider to enable key rotation on retries +- **[Governance](./governance/virtual-keys)** — Use virtual keys and routing rules to control which providers are used +- **[Observability](./observability/default)** — Track retry counts and fallback usage in your logs diff --git a/docs/features/telemetry.mdx b/docs/features/telemetry.mdx index 29741cd309..f816e30b4f 100644 --- a/docs/features/telemetry.mdx +++ b/docs/features/telemetry.mdx @@ -65,8 +65,8 @@ Base Labels: - `routing_engines_used`: Comma-separated routing engines used ("routing-rule", "governance", "loadbalancing") - `routing_rule_id`: Routing rule ID that matched the request - `routing_rule_name`: Routing rule name that matched the request -- `selected_key_id`: Selected key ID -- `selected_key_name`: Selected key name +- `selected_key_id`: ID of the key that successfully served the request (`null` on final errors) +- `selected_key_name`: Name of the key that successfully served the request (`null` on final errors) - `number_of_retries`: Number of retries - `fallback_index`: Fallback index (0 for first attempt, 1 for second attempt, etc.) - custom labels: Custom labels configured in the Bifrost configuration diff --git a/docs/integrations/vaults/aws-secrets-manager.mdx b/docs/integrations/vaults/aws-secrets-manager.mdx deleted file mode 100644 index 1913b486c2..0000000000 --- a/docs/integrations/vaults/aws-secrets-manager.mdx +++ /dev/null @@ -1,39 +0,0 @@ ---- -title: "AWS Secrets Manager" -description: "AWS Secrets Manager integration for secret management in Bifrost. Cloud-native API key storage and automatic synchronization with AWS Secrets Manager." 
-icon: "aws" ---- - -Bifrost integrates with [AWS Secrets Manager](https://aws.amazon.com/secrets-manager/) for cloud-native secret storage, allowing you to store provider API keys and virtual keys in your AWS environment and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to AWS Secrets Manager: - -```json -{ - "vault": { - "type": "aws_secrets_manager", - "region": "us-east-1", - "access_key_id": "${AWS_ACCESS_KEY_ID}", - "secret_access_key": "${AWS_SECRET_ACCESS_KEY}", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"aws_secrets_manager"` to use AWS Secrets Manager. | -| `region` | `string` | The AWS region where your secrets are stored (e.g., `"us-east-1"`). | -| `access_key_id` | `string` | AWS access key ID for authentication. Supports environment variable interpolation via `${AWS_ACCESS_KEY_ID}`. | -| `secret_access_key` | `string` | AWS secret access key for authentication. Supports environment variable interpolation via `${AWS_SECRET_ACCESS_KEY}`. | -| `sync_interval` | `string` | How often Bifrost syncs keys from AWS Secrets Manager. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). 
diff --git a/docs/integrations/vaults/azure-key-vault.mdx b/docs/integrations/vaults/azure-key-vault.mdx deleted file mode 100644 index 50fe336737..0000000000 --- a/docs/integrations/vaults/azure-key-vault.mdx +++ /dev/null @@ -1,41 +0,0 @@ ---- -title: "Azure Key Vault" -description: "Azure Key Vault integration for secret management in Bifrost. Secure API key storage and automatic synchronization with Microsoft Azure Key Vault." -icon: "microsoft" ---- - -Bifrost integrates with [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault) for secret management in Microsoft cloud environments, allowing you to store provider API keys and virtual keys in your Azure infrastructure and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to Azure Key Vault: - -```json -{ - "vault": { - "type": "azure_key_vault", - "vault_url": "https://your-keyvault.vault.azure.net/", - "client_id": "${AZURE_CLIENT_ID}", - "client_secret": "${AZURE_CLIENT_SECRET}", - "tenant_id": "${AZURE_TENANT_ID}", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"azure_key_vault"` to use Azure Key Vault. | -| `vault_url` | `string` | The full URL of your Azure Key Vault instance (e.g., `"https://your-keyvault.vault.azure.net/"`). | -| `client_id` | `string` | Azure AD application (client) ID for authentication. Supports environment variable interpolation via `${AZURE_CLIENT_ID}`. | -| `client_secret` | `string` | Azure AD client secret for authentication. Supports environment variable interpolation via `${AZURE_CLIENT_SECRET}`. | -| `tenant_id` | `string` | Azure AD tenant ID for authentication. Supports environment variable interpolation via `${AZURE_TENANT_ID}`. | -| `sync_interval` | `string` | How often Bifrost syncs keys from Azure Key Vault. 
Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/integrations/vaults/google-secret-manager.mdx b/docs/integrations/vaults/google-secret-manager.mdx deleted file mode 100644 index f630509512..0000000000 --- a/docs/integrations/vaults/google-secret-manager.mdx +++ /dev/null @@ -1,37 +0,0 @@ ---- -title: "Google Secret Manager" -description: "Google Secret Manager integration for secret management in Bifrost. Secure API key storage and automatic synchronization with Google Cloud's Secret Manager." -icon: "google" ---- - -Bifrost integrates with [Google Cloud Secret Manager](https://cloud.google.com/secret-manager) for secure key storage, allowing you to store provider API keys and virtual keys in your Google Cloud environment and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to Google Secret Manager: - -```json -{ - "vault": { - "type": "google_secret_manager", - "project_id": "your-project-id", - "credentials_file": "/path/to/service-account.json", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"google_secret_manager"` to use Google Secret Manager. | -| `project_id` | `string` | The Google Cloud project ID where your secrets are stored. | -| `credentials_file` | `string` | Path to the Google Cloud service account JSON credentials file used for authentication. 
| -| `sync_interval` | `string` | How often Bifrost syncs keys from Google Secret Manager. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/integrations/vaults/hashicorp-vault.mdx b/docs/integrations/vaults/hashicorp-vault.mdx deleted file mode 100644 index e91e2e80eb..0000000000 --- a/docs/integrations/vaults/hashicorp-vault.mdx +++ /dev/null @@ -1,39 +0,0 @@ ---- -title: "HashiCorp Vault" -description: "HashiCorp Vault integration for secret management in Bifrost. Centralized API key storage and automatic synchronization with your HashiCorp Vault instance." -icon: "vault" ---- - -Bifrost integrates with [HashiCorp Vault](https://www.vaultproject.io/) for centralized secret management, allowing you to store provider API keys and virtual keys in your existing Vault infrastructure and automatically sync them into Bifrost. - -## Configuration - -Add a `vault` block to your Bifrost configuration to connect to your HashiCorp Vault instance: - -```json -{ - "vault": { - "type": "hashicorp", - "address": "https://vault.company.com:8200", - "token": "${VAULT_TOKEN}", - "mount": "secret", - "sync_interval": "300s" - } -} -``` - -## Configuration Fields - -| Field | Type | Description | -|-------|------|-------------| -| `type` | `string` | Must be set to `"hashicorp"` to use HashiCorp Vault. | -| `address` | `string` | The full URL of your HashiCorp Vault server, including the port. | -| `token` | `string` | Authentication token for Vault access. 
Supports environment variable interpolation via `${VAULT_TOKEN}`. | -| `mount` | `string` | The secrets engine mount path in Vault (e.g., `"secret"` for the default KV secrets engine). | -| `sync_interval` | `string` | How often Bifrost syncs keys from Vault. Accepts duration strings such as `"300s"`, `"5m"`, or `"1h"`. | - - - The `sync_interval` field controls how frequently Bifrost polls your vault for key changes. Lower intervals detect changes faster but increase load on your vault server. See the [Vault Support](/enterprise/vault-support) page for full sync configuration options including `sync_paths` and `auto_deprecate`. - - -For key synchronization, deprecation management, and security configuration, see [Vault Support](/enterprise/vault-support). diff --git a/docs/media/user-provisioning/zitadel-add-role.png b/docs/media/user-provisioning/zitadel-add-role.png deleted file mode 100644 index f00212f6d2..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-role.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-add-user-select-key.png b/docs/media/user-provisioning/zitadel-add-user-select-key.png deleted file mode 100644 index ba8d8e52a9..0000000000 Binary files a/docs/media/user-provisioning/zitadel-add-user-select-key.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-auth-method.png b/docs/media/user-provisioning/zitadel-create-app-auth-method.png deleted file mode 100644 index c27a6e5772..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-auth-method.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-namne.png b/docs/media/user-provisioning/zitadel-create-app-namne.png deleted file mode 100644 index 7e220ce193..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-namne.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-create-app-uri.png 
b/docs/media/user-provisioning/zitadel-create-app-uri.png deleted file mode 100644 index 8796e77ec5..0000000000 Binary files a/docs/media/user-provisioning/zitadel-create-app-uri.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-role-assignemnt.png b/docs/media/user-provisioning/zitadel-role-assignemnt.png deleted file mode 100644 index 5f233eb436..0000000000 Binary files a/docs/media/user-provisioning/zitadel-role-assignemnt.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-select-project.png b/docs/media/user-provisioning/zitadel-select-project.png deleted file mode 100644 index 824c48dd83..0000000000 Binary files a/docs/media/user-provisioning/zitadel-select-project.png and /dev/null differ diff --git a/docs/media/user-provisioning/zitadel-token-config.png b/docs/media/user-provisioning/zitadel-token-config.png deleted file mode 100644 index 1354a62830..0000000000 Binary files a/docs/media/user-provisioning/zitadel-token-config.png and /dev/null differ diff --git a/docs/migration-guides/v1.5.0.mdx b/docs/migration-guides/v1.5.0.mdx index 7d1933aaac..5d31c01b97 100644 --- a/docs/migration-guides/v1.5.0.mdx +++ b/docs/migration-guides/v1.5.0.mdx @@ -462,6 +462,61 @@ result.ResolvedModel // actual model identifier used by the provider --- +## Breaking Change 12: `selected_key_id` Cleared on Terminal Retry Failures + +With the introduction of multi-key retry rotation, `selected_key_id` (and `selected_key_name`) in the request context are **cleared when all retry attempts fail**. Previously, these fields always reflected the key that was selected for the request, even on error. + +The `attempt_trail` is now the authoritative record of every key tried and why each attempt failed. 
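For illustration, an exhausted-retry request context might look like the following. The field names match the ones documented here; the key IDs, key names, and failure reasons are hypothetical values:

```json
{
  "selected_key_id": "",
  "selected_key_name": "",
  "attempt_trail": [
    { "key_id": "key_abc", "key_name": "openai-primary", "fail_reason": "rate_limited" },
    { "key_id": "key_def", "key_name": "openai-fallback", "fail_reason": "authentication_error" }
  ]
}
```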
+ +### What changed + +| Field | Before | After | +|---|---|---| +| `selected_key_id` | Always set, even on error | Empty string when all retries exhausted | +| `selected_key_name` | Always set, even on error | Empty string when all retries exhausted | +| `attempt_trail` | Not present | Array of `{ key_id, key_name, fail_reason }` per attempt | + +### Impact on logging and telemetry plugins + +The built-in **logging plugin** writes `selected_key_id` and `selected_key_name` directly to each log record. For multi-key requests that exhaust all retries, both fields will be empty in the stored log entry. The `attempt_trail` column captures the full per-attempt key history and is the correct field to use for failure attribution. + +The built-in **telemetry plugin** emits `selected_key_id` and `selected_key_name` as span attributes. For exhausted-retry failures these attributes will be empty strings on the error span. The `attempt_trail` span attribute contains the full rotation history. + +If you run a custom plugin or downstream log consumer that filters or groups by `selected_key_id` to track which key caused a failure, you must update it to handle the empty-string case and read from `attempt_trail` when attribution is needed. 
+ +### How to update + +**If you read `selected_key_id` from plugin context to attribute failed requests:** + +**Before:** +```go +keyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string) +// keyID was always populated, even on error +``` + +**After:** +```go +// Populated on success (or for single-key / pinned / sticky flows on error): +keyID, _ := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string) + +// For full attribution across all retry attempts (including failures): +if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok { + for _, record := range trail { + // record.KeyID, record.KeyName, record.FailReason + } +} +``` + +**If you consume `selected_key_id` from the logging REST API:** + +The `selected_key_id` field on a `LogEntry` may now be an empty string when the request failed after exhausting all retries. Use `attempt_trail` for the full per-attempt key history. + + +Single-key, pinned (`x-bf-key-id` / `x-bf-key-name`), and session-sticky requests are unaffected — they never rotate keys, so `selected_key_id` remains populated on failure for those flows. + + +--- + ## Opting Out: `version: 1` Compatibility Mode If you are not ready to adopt the new deny-by-default semantics, you can add a single field to `config.json` to restore v1.4.x behavior for all allow-list fields loaded from that file: @@ -548,6 +603,10 @@ Replace `ExtraFields.ModelRequested` with `ExtraFields.OriginalModelRequested` ( Replace `.Model` with `.RequestedModel` (and optionally `.ResolvedModel`) on any `StreamAccumulatorResult` usage. + + +If your code reads `selected_key_id` / `selected_key_name` from the request context or log entries to attribute failed requests, add a null/empty check and fall back to `attempt_trail` for the full per-attempt key history. 
+ --- diff --git a/docs/openapi/openapi.json b/docs/openapi/openapi.json index 2a8e0d7a89..84fc1d0584 100644 --- a/docs/openapi/openapi.json +++ b/docs/openapi/openapi.json @@ -28815,6 +28815,10 @@ "type": "integer", "description": "Number of members in the team" }, + "virtual_key_count": { + "type": "integer", + "description": "Number of virtual keys assigned to the team" + }, "created_at": { "type": "string", "format": "date-time" @@ -28977,6 +28981,10 @@ "type": "integer", "description": "Number of members in the team" }, + "virtual_key_count": { + "type": "integer", + "description": "Number of virtual keys assigned to the team" + }, "created_at": { "type": "string", "format": "date-time" @@ -53884,6 +53892,8 @@ }, "virtual_keys": { "type": "array", + "nullable": true, + "description": "Virtual keys assigned to this team. This field may be omitted or returned as null in some responses (for example, when a team is embedded inside a virtual-key response) to avoid nested `virtual_keys` recursion.\n", "items": { "$ref": "#/components/schemas/VirtualKey" } diff --git a/docs/openapi/schemas/management/governance.yaml b/docs/openapi/schemas/management/governance.yaml index bd04aecee7..f057682620 100644 --- a/docs/openapi/schemas/management/governance.yaml +++ b/docs/openapi/schemas/management/governance.yaml @@ -376,6 +376,11 @@ Team: $ref: '#/Budget' virtual_keys: type: array + nullable: true + description: > + Virtual keys assigned to this team. This field may be omitted or returned as null in some + responses (for example, when a team is embedded inside a virtual-key response) to avoid + nested `virtual_keys` recursion. 
items: $ref: '#/VirtualKey' profile: diff --git a/docs/openapi/schemas/management/users.yaml b/docs/openapi/schemas/management/users.yaml index 46db148f3e..aa9cb504d1 100644 --- a/docs/openapi/schemas/management/users.yaml +++ b/docs/openapi/schemas/management/users.yaml @@ -155,6 +155,9 @@ TeamObject: member_count: type: integer description: Number of members in the team + virtual_key_count: + type: integer + description: Number of virtual keys assigned to the team created_at: type: string format: date-time diff --git a/docs/providers/supported-providers/azure.mdx b/docs/providers/supported-providers/azure.mdx index 46f41a8840..56a576191c 100644 --- a/docs/providers/supported-providers/azure.mdx +++ b/docs/providers/supported-providers/azure.mdx @@ -37,6 +37,544 @@ Azure is a cloud provider offering access to OpenAI and Anthropic models through **Azure-specific**: Batch operations and Text Completions are not supported by Azure OpenAI Service. Responses API uses preview API version and is available for both OpenAI and Anthropic models. +--- + +## Setup & Configuration + +Azure requires an endpoint URL, deployment mappings, and authentication configuration. Four authentication methods are supported. + + +The `aliases` field (mapping model names to Azure deployment IDs) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use `deployments` inside `azure_key_config` instead — see the [v1.5.0 Migration Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) for details. + + +### 1. Managed Identity + +Leave `value` and all Entra ID fields empty. Bifrost calls `azidentity.NewDefaultAzureCredential(nil)`, which auto-detects the system-assigned or user-assigned managed identity on Azure VMs, App Service, AKS, and Azure Container Instances. No credentials need to be stored or rotated. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. 
Under **Authentication Method**, select **"Default Credential"** +4. Set **Endpoint**: Your Azure OpenAI resource URL (e.g., `https://your-org.openai.azure.com`) +5. Set **API Version** (Optional): e.g., `2024-10-21` +6. Configure **Aliases**: Map model names to deployment IDs (e.g., `gpt-4o` → `my-gpt4o-deployment`) +7. Save + +Ensure a managed identity with Cognitive Services permissions is attached to the Azure resource running Bifrost. + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Managed Identity — leave value empty) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-managed-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `azure_key_config`: +```json +"azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21", + "deployments": { + "gpt-4o": "my-gpt4o-deployment" + } +} +``` + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-managed-identity", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. 
+ + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty — Bifrost uses the managed identity + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. Default Credential Chain (DefaultAzureCredential) + +Leave `value` and all Entra ID fields empty. Bifrost calls `azidentity.NewDefaultAzureCredential(nil)` — the same function as the Managed Identity section — which tries credential sources in this order: + +1. Environment variables (`AZURE_CLIENT_ID` + `AZURE_CLIENT_SECRET` + `AZURE_TENANT_ID`, or certificate/username variants) +2. Workload Identity (AKS with Workload Identity Federation) +3. Managed Identity (Azure VMs, App Service, AKS, Container Instances) +4. Azure CLI (`az login`) +5. Azure Developer CLI (`azd auth login`) + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Default Credential"** +4. Set **Endpoint**: Your Azure OpenAI resource URL (e.g., `https://your-org.openai.azure.com`) +5. Set **API Version** (Optional): e.g., `2024-10-21` +6. Configure **Aliases**: Map model names to deployment IDs +7. Save + +Ensure the appropriate credentials are available in the environment where Bifrost runs — set `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET`, `AZURE_TENANT_ID` env vars, or run `az login`. 
+ + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Default Credential Chain — leave value empty) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-default-chain", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `azure_key_config`. + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-default-chain", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty — Bifrost uses DefaultAzureCredential + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. 
Azure Entra ID (Service Principal) + +Set `client_id`, `client_secret`, and `tenant_id` to authenticate with a Service Principal. This takes priority over API key and managed identity. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Entra ID (Service Principal)"** +4. Set **Client ID**: Your Azure Entra ID client ID +5. Set **Client Secret**: Your Azure Entra ID client secret +6. Set **Tenant ID**: Your Azure Entra ID tenant ID +7. Set **Endpoint**: Your Azure OpenAI resource URL +8. Set **API Version** (Optional): e.g., `2024-08-01-preview` +9. Configure **Aliases**: Map model names to deployment IDs +10. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (Service Principal) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-entra-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "client_id": "env.AZURE_CLIENT_ID", + "client_secret": "env.AZURE_CLIENT_SECRET", + "tenant_id": "env.AZURE_TENANT_ID", + "scopes": ["https://cognitiveservices.azure.com/.default"], + "api_version": "2024-08-01-preview" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Move the model mappings from `aliases` into `azure_key_config.deployments`. 
+ + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-entra-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "client_id": "env.AZURE_CLIENT_ID", + "client_secret": "env.AZURE_CLIENT_SECRET", + "tenant_id": "env.AZURE_TENANT_ID", + "scopes": ["https://cognitiveservices.azure.com/.default"], + "api_version": "2024-08-01-preview" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty for Service Principal auth + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + "claude-3-5-sonnet": "my-claude-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + ClientID: schemas.NewEnvVar(os.Getenv("AZURE_CLIENT_ID")), + ClientSecret: schemas.NewEnvVar(os.Getenv("AZURE_CLIENT_SECRET")), + TenantID: schemas.NewEnvVar(os.Getenv("AZURE_TENANT_ID")), + Scopes: []string{"https://cognitiveservices.azure.com/.default"}, + APIVersion: schemas.NewEnvVar("2024-08-01-preview"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +**Required Azure roles:** +- OpenAI models: `Cognitive Services OpenAI User` +- Anthropic models: `Cognitive Services AI Services User` + +### 4. Direct Authentication (API Key) + +Provide the Azure API key in the `value` field. 
Use this for simple setups without managed identity or Service Principal. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Azure"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"API Key"** +4. Set **API Key**: Your Azure API key +5. Set **Endpoint**: Your Azure OpenAI resource URL +6. Set **API Version** (Optional): e.g., `2024-10-21` +7. Configure **Aliases**: Map model names to deployment IDs +8. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "azure"}' + +# Step 2: Create a key (API Key auth) +curl -X POST http://localhost:8080/api/providers/azure/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "azure-api-key", + "value": "env.AZURE_API_KEY", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Move the model mappings from `aliases` into `azure_key_config.deployments`. + + + + + + +```json +{ + "providers": { + "azure": { + "keys": [ + { + "name": "azure-api-key", + "value": "env.AZURE_API_KEY", + "models": ["*"], + "weight": 1.0, + "aliases": { + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment" + }, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "api_version": "2024-10-21" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `azure_key_config` instead of the top-level `aliases` field. 
+ + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Azure: + return []schemas.Key{ + { + Value: *schemas.NewEnvVar("env.AZURE_API_KEY"), + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "gpt-4o": "my-gpt4o-deployment", + "gpt-4o-mini": "my-mini-deployment", + }, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: *schemas.NewEnvVar(os.Getenv("AZURE_ENDPOINT")), + APIVersion: schemas.NewEnvVar("2024-10-21"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + + +**Authentication precedence:** (1) Entra ID if `client_id`, `client_secret`, and `tenant_id` are all set; (2) API key if `value` is non-empty; (3) DefaultAzureCredential (managed identity) if neither is provided. + + +**`azure_key_config` fields:** + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `endpoint` | Yes | — | Azure OpenAI resource endpoint URL | +| `api_version` | No | `2024-10-21` | Azure API version | +| `client_id` | No | — | Entra ID client ID (Service Principal auth) | +| `client_secret` | No | — | Entra ID client secret (Service Principal auth) | +| `tenant_id` | No | — | Entra ID tenant ID (Service Principal auth) | +| `scopes` | No | `["https://cognitiveservices.azure.com/.default"]` | OAuth scopes for token requests | + +**Key-level fields:** + +| Field | Required | Description | +|-------|----------|-------------| +| `aliases` | No | Map model names to Azure deployment IDs (v1.5.0-prerelease2+) | +| `value` | No | Azure API key (leave empty for Entra ID or managed identity) | +| `models` | Yes | Models this key can serve; use `["*"]` to allow all | + +--- + ## Beta Headers For Anthropic models on Azure, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request. Azure supports most Anthropic beta features. 
@@ -577,17 +1115,4 @@ Azure routes video generation to OpenAI's Sora models via the Azure OpenAI-compa ## Setup & Configuration -Azure requires endpoint URLs, deployment mappings, and API version configuration. For detailed instructions on setting up Azure authentication, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - Azure](../../quickstart/gateway/provider-configuration#azure)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. - - - - -See **[Provider-Specific Authentication - Azure](../../quickstart/go-sdk/provider-configuration#azure)** in the Go SDK Quickstart for programmatic configuration examples. - - - +See the [Setup & Configuration](#setup--configuration) section at the top of this page for authentication instructions and full configuration examples. diff --git a/docs/providers/supported-providers/bedrock.mdx b/docs/providers/supported-providers/bedrock.mdx index 52ff79ca3f..a6bd44bcb7 100644 --- a/docs/providers/supported-providers/bedrock.mdx +++ b/docs/providers/supported-providers/bedrock.mdx @@ -52,6 +52,479 @@ AWS Bedrock supports multiple model families (Claude, Nova, Mistral, Llama, Cohe **Limitations**: Images must be in base64 or data URI format (remote URLs not supported). Text completion streaming is not supported. +--- + +## Setup & Configuration + +Bifrost signs every Bedrock request with AWS Signature Version 4 (SigV4). Four credential methods are supported — choose the one that matches your deployment environment. + + +The `aliases` field (mapping model names to inference profile IDs, ARNs, or deployment identifiers) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use `deployments` inside `bedrock_key_config` instead — see the [v1.5.0 Migration Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) for details. + + +### 1. Explicit Credentials + +Provide `access_key` and `secret_key` directly. 
Optionally include `session_token` for temporary credentials. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Explicit Credentials"** +4. Set **Access Key**: Your AWS access key ID +5. Set **Secret Key**: Your AWS secret access key +6. Set **Session Token** (Optional): For temporary/assumed credentials +7. Set **Region**: e.g., `us-east-1` +8. Configure **Aliases**: Map model names to inference profile IDs +9. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "bedrock"}' + +# Step 2: Create a key (Explicit Credentials) +curl -X POST http://localhost:8080/api/providers/bedrock/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "bedrock-key", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `bedrock_key_config`: +```json +"bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "region": "us-east-1", + "deployments": { + "claude-3-5-sonnet": "arn:aws:bedrock:us-east-1::foundation-model/..." 
+ } +} +``` + + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-key", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "access_key": "env.AWS_ACCESS_KEY_ID", + "secret_key": "env.AWS_SECRET_ACCESS_KEY", + "session_token": "env.AWS_SESSION_TOKEN", + "region": "us-east-1" + } + } + ] + } + } +} +``` + + +On **v1.4.x**, use `deployments` inside `bedrock_key_config` instead of the top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + AccessKey: *schemas.NewEnvVar("env.AWS_ACCESS_KEY_ID"), + SecretKey: *schemas.NewEnvVar("env.AWS_SECRET_ACCESS_KEY"), + SessionToken: schemas.NewEnvVar("env.AWS_SESSION_TOKEN"), + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. IAM Role (IRSA / Instance Profile) + +Leave `access_key` and `secret_key` empty. Bifrost calls AWS SDK's `config.LoadDefaultConfig()`, which runs the full credential chain. This section covers IAM-role-based sources: EKS IRSA (`AWS_WEB_IDENTITY_TOKEN_FILE` + `AWS_ROLE_ARN`), ECS task role, and EC2 instance profile (IMDS). No static credentials are stored anywhere. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"IAM Role (Inherited)"** +4. Set **Region**: e.g., `us-east-1` +5. Configure **Aliases** if needed +6. 
Save + +Ensure your workload has an IAM role with Bedrock permissions attached (via IRSA, ECS task role, or EC2 instance profile). + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "bedrock"}' + +# Step 2: Create a key (IAM Role — no credentials) +curl -X POST http://localhost:8080/api/providers/bedrock/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "bedrock-iam", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "region": "us-east-1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Replace the top-level `aliases` with `"deployments"` inside `bedrock_key_config`. + + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-iam", + "models": ["*"], + "weight": 1.0, + "aliases": { + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0" + }, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + // Leave Value empty — Bifrost uses the IAM role bound to the workload + Aliases: schemas.KeyAliases{ + "claude-3-5-sonnet": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + }, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + // Leave AccessKey and SecretKey empty — resolved from IRSA/instance profile + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. 
Default Credential Chain + +Leave `access_key` and `secret_key` empty. Bifrost calls the same `config.LoadDefaultConfig()` as the IAM Role section — the full AWS SDK v2 credential chain runs. Use this when credentials come from environment variables or shared credential files rather than an attached IAM role. + +Full AWS SDK v2 credential resolution order: +1. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN`) +2. Web Identity Token File / IRSA (`AWS_WEB_IDENTITY_TOKEN_FILE` + `AWS_ROLE_ARN`) +3. Shared credentials and config files (`~/.aws/credentials`, `~/.aws/config`) +4. ECS container credentials +5. EC2 instance metadata / IMDS + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"IAM Role (Inherited)"** — this covers both assigned IAM roles and the default credential chain (env vars, `~/.aws/credentials`) +4. Set **Region**: e.g., `us-east-1` +5. Configure **Aliases** if needed +6. Save + +Ensure `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are set in the environment where Bifrost runs, or that `~/.aws/credentials` is configured. + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "bedrock"}' + +# Step 2: Create a key (Default Credential Chain — no credentials) +curl -X POST http://localhost:8080/api/providers/bedrock/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "bedrock-default-chain", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1" + } + }' +``` + + +**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. 
+ + + + + + +```json +{ + "providers": { + "bedrock": { + "keys": [ + { + "name": "bedrock-default-chain", + "models": ["*"], + "weight": 1.0, + "bedrock_key_config": { + "region": "us-east-1" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Bedrock: + return []schemas.Key{ + { + Models: []string{"*"}, + Weight: 1.0, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + // Leave AccessKey and SecretKey empty — resolved from env vars or ~/.aws/credentials + Region: schemas.NewEnvVar("us-east-1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 4. STS AssumeRole + +Set `role_arn` to assume an IAM role before signing requests. STS AssumeRole can layer on top of either explicit credentials or the default credential chain. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"AWS Bedrock"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"IAM Role (Inherited)"** or **"Explicit Credentials"** — AssumeRole fields are available on both tabs +4. Set **Region**: e.g., `us-east-1` +5. Set **Assume Role ARN**: The IAM role ARN (e.g., `arn:aws:iam::123456789012:role/BedrockRole`) +6. Set **External ID** (Optional): Required when the role's trust policy demands it +7. Set **Session Name** (Optional): Identifies the session in CloudTrail (default: `bifrost-session`) +8. 
Save
+
+
+
+
+
+```bash
+# Step 1: Create the provider
+curl -X POST http://localhost:8080/api/providers \
+  -H "Content-Type: application/json" \
+  -d '{"provider": "bedrock"}'
+
+# Step 2: Create a key (STS AssumeRole)
+curl -X POST http://localhost:8080/api/providers/bedrock/keys \
+  -H "Content-Type: application/json" \
+  -d '{
+    "name": "bedrock-assume-role",
+    "models": ["*"],
+    "weight": 1.0,
+    "bedrock_key_config": {
+      "role_arn": "env.AWS_ROLE_ARN",
+      "external_id": "env.AWS_EXTERNAL_ID",
+      "session_name": "bifrost-session",
+      "region": "us-east-1"
+    }
+  }'
+```
+
+
+**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint.
+
+
+
+
+
+
+```json
+{
+  "providers": {
+    "bedrock": {
+      "keys": [
+        {
+          "name": "bedrock-assume-role",
+          "models": ["*"],
+          "weight": 1.0,
+          "bedrock_key_config": {
+            "role_arn": "env.AWS_ROLE_ARN",
+            "external_id": "env.AWS_EXTERNAL_ID",
+            "session_name": "bifrost-session",
+            "region": "us-east-1"
+          }
+        }
+      ]
+    }
+  }
+}
+```
+
+
+
+
+
+```go
+func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
+	switch provider {
+	case schemas.Bedrock:
+		return []schemas.Key{
+			{
+				Models: []string{"*"},
+				Weight: 1.0,
+				BedrockKeyConfig: &schemas.BedrockKeyConfig{
+					RoleARN:         schemas.NewEnvVar("env.AWS_ROLE_ARN"),
+					ExternalID:      schemas.NewEnvVar("env.AWS_EXTERNAL_ID"), // optional
+					RoleSessionName: schemas.NewEnvVar("bifrost-session"),     // optional
+					Region:          schemas.NewEnvVar("us-east-1"),
+				},
+			},
+		}, nil
+	}
+	return nil, fmt.Errorf("provider %s not supported", provider)
+}
+```
+
+
+
+
+
+
+AssumeRole requires a valid source identity — it works when credentials are available via explicit `access_key`/`secret_key` or via the default credential chain. If no credentials are available from either source, AssumeRole will fail.
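Taken together, the three authentication methods reduce to a small precedence rule: explicit `access_key`/`secret_key` take priority over the default chain, and `role_arn` layers AssumeRole on top of whichever base source applies. A minimal Go sketch of that rule (the struct and function are illustrative only, not Bifrost's internals):

```go
package main

import "fmt"

// keyConfig mirrors the auth-related bedrock_key_config fields
// (hypothetical type, for illustration only).
type keyConfig struct {
	AccessKey string
	SecretKey string
	RoleARN   string
}

// authStrategy describes how credentials are resolved for a key:
// explicit static keys take precedence over the default chain, and
// a role_arn layers STS AssumeRole on top of either base source.
func authStrategy(c keyConfig) string {
	base := "default credential chain"
	if c.AccessKey != "" && c.SecretKey != "" {
		base = "explicit static credentials"
	}
	if c.RoleARN != "" {
		return "STS AssumeRole on top of " + base
	}
	return base
}

func main() {
	fmt.Println(authStrategy(keyConfig{}))
	fmt.Println(authStrategy(keyConfig{RoleARN: "arn:aws:iam::123456789012:role/BedrockRole"}))
	fmt.Println(authStrategy(keyConfig{AccessKey: "AKIAEXAMPLE", SecretKey: "example"}))
}
```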
+ + +**`bedrock_key_config` fields:** + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `region` | Yes | — | AWS region (e.g., `us-east-1`) | +| `access_key` | No | — | AWS access key ID | +| `secret_key` | No | — | AWS secret access key | +| `session_token` | No | — | AWS session token (for temporary credentials) | +| `arn` | No | — | ARN prefix for constructing inference profile URLs (see [Inference Profiles](#inference-profiles--arn-configuration)) | +| `role_arn` | No | — | IAM role ARN for STS AssumeRole | +| `external_id` | No | — | External ID for AssumeRole (when required by trust policy) | +| `session_name` | No | `bifrost-session` | Session name for AssumeRole CloudTrail logs | + +**Key-level fields:** + +| Field | Required | Description | +|-------|----------|-------------| +| `aliases` | No | Map model names to inference profile IDs or Bedrock model IDs (v1.5.0-prerelease2+) | +| `models` | Yes | Models this key can serve; use `["*"]` to allow all | + +--- + ## Beta Headers For Claude models on Bedrock, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request. @@ -1271,7 +1744,7 @@ Set `role_arn` to assume an IAM role before signing requests. AssumeRole require | `session_name` | No | `bifrost-session` | Identifies the session in CloudTrail logs | -## Setup & Configuration +## Inference Profiles & ARN Configuration ### How to Use ARNs and Application Inference Profiles @@ -1321,21 +1794,6 @@ When using AWS Bedrock inference profiles or application inference profiles, you } ``` -For detailed instructions on setting up AWS Bedrock authentication including credentials, IAM roles, regions, and deployment mapping, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - AWS Bedrock](../../quickstart/gateway/provider-configuration#aws-bedrock)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. 
- - - - -See **[Provider-Specific Authentication - AWS Bedrock](../../quickstart/go-sdk/provider-configuration#aws-bedrock)** in the Go SDK Quickstart for programmatic configuration examples. - - - - ### Endpoints - **Runtime API**: `bedrock-runtime.{region}.amazonaws.com/model/{path}` diff --git a/docs/providers/supported-providers/vertex.mdx b/docs/providers/supported-providers/vertex.mdx index 4ea108d733..1d19b05984 100644 --- a/docs/providers/supported-providers/vertex.mdx +++ b/docs/providers/supported-providers/vertex.mdx @@ -7,6 +7,7 @@ icon: "v" ## Overview Vertex AI is Google's unified ML platform providing access to Google's Gemini models, Anthropic Claude models, and other third-party LLMs through a single API. Bifrost performs conversions including: + - **Multi-model support** - Unified interface for Gemini, Anthropic, and third-party models - **OAuth2 authentication** - Service account credentials with automatic token refresh - **Project and region management** - Automatic endpoint construction from GCP project/region @@ -17,28 +18,411 @@ Vertex AI is Google's unified ML platform providing access to Google's Gemini mo ### Supported Operations -| Operation | Non-Streaming | Streaming | Endpoint | -|-----------|---------------|-----------|----------| -| Chat Completions | ✅ | ✅ | `/generate` | -| Responses API | ✅ | ✅ | `/messages` | -| Embeddings | ✅ | - | `/embeddings` | -| Image Generation | ✅ | - | `/generateContent` or `/predict` (Imagen) | -| Image Edit | ✅ | - | `/generateContent` or `/predict` (Imagen) | -| Video Generation | ✅ | - | `/predictLongRunning` (Veo models only) | -| Image Variation | ❌ | - | Not supported | -| List Models | ✅ | - | `/models` | -| Text Completions | ❌ | ❌ | - | -| Speech (TTS) | ❌ | ❌ | - | -| Transcriptions (STT) | ❌ | ❌ | - | -| Files | ❌ | ❌ | - | -| Batch | ❌ | ❌ | - | +| Operation | Non-Streaming | Streaming | Endpoint | +| -------------------- | ------------- | --------- | 
----------------------------------------- | +| Chat Completions | ✅ | ✅ | `/generate` | +| Responses API | ✅ | ✅ | `/messages` | +| Embeddings | ✅ | - | `/embeddings` | +| Image Generation | ✅ | - | `/generateContent` or `/predict` (Imagen) | +| Image Edit | ✅ | - | `/generateContent` or `/predict` (Imagen) | +| Video Generation | ✅ | - | `/predictLongRunning` (Veo models only) | +| Image Variation | ❌ | - | Not supported | +| List Models | ✅ | - | `/models` | +| Text Completions | ❌ | ❌ | - | +| Speech (TTS) | ❌ | ❌ | - | +| Transcriptions (STT) | ❌ | ❌ | - | +| Files | ❌ | ❌ | - | +| Batch | ❌ | ❌ | - | **Unsupported Operations** (❌): Text Completions, Speech, Transcriptions, Files, and Batch are not supported by Vertex AI. These return `UnsupportedOperationError`. **Vertex-specific**: Endpoints vary by model type. Responses API available for both Gemini and Anthropic models. + + + +--- + +## Setup & Configuration + +Vertex AI requires Google Cloud project configuration and authentication credentials. Three authentication methods are supported. + + + The `aliases` field (mapping model names to fine-tuned model IDs or endpoint + identifiers) requires **v1.5.0-prerelease2 or later**. On v1.4.x, use + `deployments` inside `vertex_key_config` instead — see the [v1.5.0 Migration + Guide](/migration-guides/v1.5.0#breaking-change-9-provider-deployments-removed-migrate-to-aliases) + for details. + + +### 1. Service Account JSON (Recommended for Production) + +Provide a credential JSON string in `auth_credentials`. The JSON must contain a `type` field. Supported types: `service_account` (most common), `impersonated_service_account`, `authorized_user`, `external_account`, `external_account_authorized_user`. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Service Account (JSON)"** +4. 
Set **Project ID**: Your Google Cloud project ID +5. Set **Region**: e.g., `us-central1` +6. Set **Auth Credentials**: Paste your service account JSON or reference an env var (e.g., `env.VERTEX_CREDENTIALS`) +7. Configure **Aliases**: Map model names to fine-tuned model IDs (if using fine-tuned models) +8. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (Service Account JSON) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-sa-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. +- Use `deployments` inside `vertex_key_config` instead of the top-level `aliases` field for fine-tuned model mappings. + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-sa-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ] + } + } +} +``` + + + On **v1.4.x**, use `deployments` inside `vertex_key_config` instead of the + top-level `aliases` field for fine-tuned model mappings. 
+ + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, // Leave empty when using service account credentials + Models: []string{"*"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + Region: *schemas.NewEnvVar("us-central1"), + AuthCredentials: *schemas.NewEnvVar("env.VERTEX_CREDENTIALS"), // full service account JSON + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 2. Application Default Credentials + +Leave `auth_credentials` empty. Bifrost calls `google.FindDefaultCredentials()` — Google's ADC library — which resolves credentials in this order: + +1. `GOOGLE_APPLICATION_CREDENTIALS` env var (path to a JSON credential file) +2. Application default credential file (`~/.config/gcloud/application_default_credentials.json`, written by `gcloud auth application-default login`) +3. GCE/GKE/Cloud Run/App Engine metadata server (attached service account or Workload Identity) + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"Service Account (Attached)"** +4. Set **Project ID**: Your Google Cloud project ID +5. Set **Region**: e.g., `us-central1` +6. Configure **Aliases** if needed +7. Save + +Ensure `GOOGLE_APPLICATION_CREDENTIALS` is set in your environment, or that Workload Identity / gcloud is configured. 
+ + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (Application Default Credentials) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-adc-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "" + } + }' +``` + + +**On v1.4.x**, pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-adc-key", + "value": "", + "models": ["*"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "" + } + } + ] + } + } +} +``` + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: schemas.EnvVar{}, + Models: []string{"*"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + Region: *schemas.NewEnvVar("us-central1"), + // Leave AuthCredentials empty — uses Application Default Credentials + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + +### 3. API Key (Gemini and Fine-Tuned Models Only) + +Set `value` to your Vertex API key. API key authentication is supported only for Gemini models and fine-tuned Gemini models. For Anthropic models on Vertex, use Service Account or Application Default Credentials. + + + + + +1. Navigate to **"Model Providers"** → **"Configurations"** → **"Google Vertex"** +2. 
Click **"Add Key"** (or edit an existing key) +3. Under **Authentication Method**, select **"API Key"** +4. Set **API Key**: Your Vertex AI API key +5. Set **Project ID**: Your Google Cloud project ID +6. Set **Region**: e.g., `us-central1` +7. Set **Project Number** (Optional): Required only when using fine-tuned models +8. Configure **Aliases**: Map short names to fine-tuned model IDs (e.g., `my-model` → `123456789`) +9. Save + + + + + +```bash +# Step 1: Create the provider +curl -X POST http://localhost:8080/api/providers \ + -H "Content-Type: application/json" \ + -d '{"provider": "vertex"}' + +# Step 2: Create a key (API Key — Gemini + fine-tuned models) +curl -X POST http://localhost:8080/api/providers/vertex/keys \ + -H "Content-Type: application/json" \ + -d '{ + "name": "vertex-api-key", + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"], + "weight": 1.0, + "aliases": { + "my-fine-tuned-model": "123456789" + }, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "us-central1" + } + }' +``` + + +**On v1.4.x**, two differences apply: +- Pass `keys` directly in the `POST /api/providers` body — there is no separate `/api/providers/{provider}/keys` endpoint. 
+- Replace the top-level `aliases` with `"deployments"` inside `vertex_key_config`: +```json +"vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "deployments": { + "my-fine-tuned-model": "123456789" + } +} +``` + + + + + + +```json +{ + "providers": { + "vertex": { + "keys": [ + { + "name": "vertex-api-key", + "value": "env.VERTEX_API_KEY", + "models": ["gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"], + "weight": 1.0, + "aliases": { + "my-fine-tuned-model": "123456789" + }, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "project_number": "env.VERTEX_PROJECT_NUMBER", + "region": "us-central1" + } + } + ] + } + } +} +``` + + + On **v1.4.x**, use `deployments` inside `vertex_key_config` instead of the + top-level `aliases` field. + + + + + + +```go +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + switch provider { + case schemas.Vertex: + return []schemas.Key{ + { + Value: *schemas.NewEnvVar("env.VERTEX_API_KEY"), // only when using Gemini or fine-tuned models + Models: []string{"gemini-pro", "gemini-2.0-flash", "my-fine-tuned-model"}, + Weight: 1.0, + Aliases: schemas.KeyAliases{ + "my-fine-tuned-model": "123456789", + }, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: *schemas.NewEnvVar("env.VERTEX_PROJECT_ID"), + ProjectNumber: *schemas.NewEnvVar("env.VERTEX_PROJECT_NUMBER"), // required for fine-tuned models + Region: *schemas.NewEnvVar("us-central1"), + }, + }, + }, nil + } + return nil, fmt.Errorf("provider %s not supported", provider) +} +``` + + + + + + + Vertex AI support for fine-tuned models is currently in beta. Requests to + non-Gemini fine-tuned models may fail, so please test and report any issues. 
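Conceptually, the `aliases` mapping shown above is a plain name-to-ID lookup applied before the request is built: a requested model name is swapped for its fine-tuned model ID, and unmapped names pass through unchanged. A minimal sketch (not Bifrost's actual code):

```go
package main

import "fmt"

// resolveModel maps a requested model name through the key's aliases,
// falling back to the name itself when no alias is configured.
func resolveModel(aliases map[string]string, model string) string {
	if id, ok := aliases[model]; ok {
		return id
	}
	return model
}

func main() {
	aliases := map[string]string{"my-fine-tuned-model": "123456789"}
	fmt.Println(resolveModel(aliases, "my-fine-tuned-model")) // 123456789
	fmt.Println(resolveModel(aliases, "gemini-2.0-flash"))    // gemini-2.0-flash
}
```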
+**`vertex_key_config` fields:**
+
+| Field              | Required | Description                                                |
+| ------------------ | -------- | ---------------------------------------------------------- |
+| `project_id`       | Yes      | Google Cloud project ID                                    |
+| `region`           | Yes      | GCP region (e.g., `us-central1`, `europe-west1`, `global`) |
+| `auth_credentials` | No       | Service account JSON string (leave empty for ADC)          |
+| `project_number`   | No       | GCP project number (required for fine-tuned models)        |
+
+**Key-level fields:**
+
+| Field     | Required | Description                                                                               |
+| --------- | -------- | ----------------------------------------------------------------------------------------- |
+| `value`   | No       | Vertex API key (Gemini and fine-tuned models only; leave empty for Service Account / ADC) |
+| `aliases` | No       | Map model names to fine-tuned model IDs or endpoint identifiers (v1.5.0-prerelease2+)     |
+| `models`  | Yes      | Models this key can serve; use `["*"]` to allow all                                       |
+
+---
+
 ## Beta Headers
 
 For Anthropic models on Vertex AI, Bifrost validates `anthropic-beta` headers and drops unsupported headers from the request.
@@ -50,7 +434,10 @@ For Anthropic models on Vertex AI, Bifrost validates `anthropic-beta` headers an
 
 You can override these defaults per provider via the **Beta Headers** tab in provider configuration or via [`beta_header_overrides`](/quickstart/gateway/provider-configuration#beta-header-overrides). See the full support matrix in the [Anthropic provider docs](/providers/supported-providers/anthropic#beta-headers).
- Vertex AI Beta Headers configuration tab showing supported and unsupported Anthropic beta features with override options + Vertex AI Beta Headers configuration tab showing supported and unsupported Anthropic beta features with override options --- @@ -61,9 +448,9 @@ You can override these defaults per provider via the **Beta Headers** tab in pro ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `model` | Maps to Vertex model ID | Region-specific endpoint constructed automatically | +| Parameter | Vertex Handling | Notes | +| ---------------- | ------------------------- | ---------------------------------------------------- | +| `model` | Maps to Vertex model ID | Region-specific endpoint constructed automatically | | All other params | Model-specific conversion | Converted per underlying provider (Gemini/Anthropic) | ### Key Configuration @@ -81,6 +468,7 @@ The key configuration for Vertex requires Google Cloud credentials: ``` **Configuration Details**: + - `project_id` - GCP project ID (required) - `region` - GCP region for API endpoints (required) - Examples: `us-central1`, `us-west1`, `eu-west1`, `global` @@ -89,8 +477,9 @@ The key configuration for Vertex requires Google Cloud credentials: ### Authentication Methods 1. **Service Account JSON** (recommended for production) + ```json - {"auth_credentials": "{full-service-account-json}"} + { "auth_credentials": "{full-service-account-json}" } ``` 2. 
**Application Default Credentials** (for local development)
@@ -137,12 +526,12 @@ Refer to [Anthropic documentation](/providers/supported-providers/anthropic) for
 
 The region determines the API endpoint:
 
-| Region | Endpoint | Purpose |
-|--------|----------|---------|
-| `us-central1` | `us-central1-aiplatform.googleapis.com` | US Central |
-| `us-west1` | `us-west1-aiplatform.googleapis.com` | US West |
-| `eu-west1` | `eu-west1-aiplatform.googleapis.com` | Europe West |
-| `global` | `aiplatform.googleapis.com` | Global (no region prefix) |
+| Region         | Endpoint                                 | Purpose                   |
+| -------------- | ---------------------------------------- | ------------------------- |
+| `us-central1`  | `us-central1-aiplatform.googleapis.com`  | US Central                |
+| `us-west1`     | `us-west1-aiplatform.googleapis.com`     | US West                   |
+| `europe-west1` | `europe-west1-aiplatform.googleapis.com` | Europe West               |
+| `global`       | `aiplatform.googleapis.com`              | Global (no region prefix) |
 
 Availability varies by region. Check [GCP documentation](https://cloud.google.com/vertex-ai/docs/general/locations) for model availability.
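The endpoint pattern in the table reduces to a one-line rule, with `global` as the only special case. A sketch (illustrative, not Bifrost's actual code):

```go
package main

import "fmt"

// vertexHost builds the Vertex AI API host for a region: regional
// endpoints get a region prefix, while "global" uses the bare host.
func vertexHost(region string) string {
	if region == "global" {
		return "aiplatform.googleapis.com"
	}
	return region + "-aiplatform.googleapis.com"
}

func main() {
	fmt.Println(vertexHost("us-central1")) // us-central1-aiplatform.googleapis.com
	fmt.Println(vertexHost("global"))      // aiplatform.googleapis.com
}
```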
@@ -163,12 +552,12 @@ The Responses API is available for both Anthropic (Claude) and Gemini models on ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `instructions` | Becomes system message | Model-specific conversion | -| `input` | Converted to messages | String or array support | -| `max_output_tokens` | Model-specific field mapping | Gemini vs Anthropic conversion | -| All other params | Model-specific conversion | Converted per underlying provider | +| Parameter | Vertex Handling | Notes | +| ------------------- | ---------------------------- | --------------------------------- | +| `instructions` | Becomes system message | Model-specific conversion | +| `input` | Converted to messages | String or array support | +| `max_output_tokens` | Model-specific field mapping | Gemini vs Anthropic conversion | +| All other params | Model-specific conversion | Converted per underlying provider | ### Gemini Models @@ -177,6 +566,7 @@ For Gemini models, conversion follows Gemini's Responses API format. 
### Anthropic Models (Claude) For Anthropic models, conversion follows Anthropic's message format: + - `instructions` becomes system message - `reasoning` mapped to `thinking` structure @@ -234,9 +624,9 @@ Embeddings are supported for Gemini and other models that support embedding gene ### Core Parameters -| Parameter | Vertex Mapping | Notes | -|-----------|---|-------| -| `input` | `instances[].content` | Text to embed | +| Parameter | Vertex Mapping | Notes | +| ------------ | --------------------------------- | -------------------- | +| `input` | `instances[].content` | Text to embed | | `dimensions` | `parameters.outputDimensionality` | Optional output size | ### Advanced Parameters @@ -287,11 +677,11 @@ resp, err := client.EmbeddingRequest(schemas.NewBifrostContext(ctx, schemas.NoDe #### Embedding Parameters -| Parameter | Type | Description | -|-----------|------|-------------| -| `task_type` | string | Task type hint: `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING` (optional) | -| `title` | string | Optional title to help model produce better embeddings (used with task_type) | -| `autoTruncate` | boolean | Auto-truncate input to max tokens (defaults to true) | +| Parameter | Type | Description | +| -------------- | ------- | ------------------------------------------------------------------------------------------------------------------------- | +| `task_type` | string | Task type hint: `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING` (optional) | +| `title` | string | Optional title to help model produce better embeddings (used with task_type) | +| `autoTruncate` | boolean | Auto-truncate input to max tokens (defaults to true) | ### Task Type Effects @@ -322,6 +712,7 @@ Embeddings response includes vectors and truncation information: ``` **Response Fields**: + - `values` - Embedding vector as floats - `statistics.token_count` - Input token count - 
`statistics.truncated` - Whether input was truncated due to length @@ -336,11 +727,11 @@ Image Generation is supported for Gemini and Imagen on Vertex AI. The provider a ### Core Parameter Mapping -| Parameter | Vertex Handling | Notes | -|-----------|---|-------| -| `model` | Mapped to deployment/model identifier | Model type detected automatically | -| `prompt` | Model-specific conversion | Converted per underlying provider (Gemini/Imagen) | -| All other params | Model-specific conversion | Converted per underlying provider | +| Parameter | Vertex Handling | Notes | +| ---------------- | ------------------------------------- | ------------------------------------------------- | +| `model` | Mapped to deployment/model identifier | Model type detected automatically | +| `prompt` | Model-specific conversion | Converted per underlying provider (Gemini/Imagen) | +| All other params | Model-specific conversion | Converted per underlying provider | ### Model Type Detection @@ -418,29 +809,27 @@ Image generation streaming is not supported by Vertex AI. # 5. Image Edit - -Requests use **multipart/form-data**, not JSON. - +Requests use **multipart/form-data**, not JSON. Image Edit is supported for Gemini and Imagen models on Vertex AI. The provider automatically routes to the appropriate format based on the model type. 
**Request Parameters** -| Parameter | Type | Required | Notes | -|-----------|------|----------|-------| -| `model` | string | ✅ | Model identifier (must be Gemini or Imagen model) | -| `prompt` | string | ✅ | Text description of the edit | -| `image[]` | binary | ✅ | Image file(s) to edit (supports multiple images) | -| `mask` | binary | ❌ | Mask image file | -| `type` | string | ❌ | Edit type: `"inpainting"`, `"outpainting"`, `"inpaint_removal"`, `"bgswap"` (Imagen only) | -| `n` | int | ❌ | Number of images to generate (1-10) | -| `output_format` | string | ❌ | Output format: `"png"`, `"webp"`, `"jpeg"` | -| `output_compression` | int | ❌ | Compression level (0-100%) | -| `seed` | int | ❌ | Seed for reproducibility (via `ExtraParams["seed"]`) | -| `negative_prompt` | string | ❌ | Negative prompt (via `ExtraParams["negativePrompt"]`) | -| `maskMode` | string | ❌ | Mask mode (via `ExtraParams["maskMode"]`, Imagen only): `"MASK_MODE_USER_PROVIDED"`, `"MASK_MODE_BACKGROUND"`, `"MASK_MODE_FOREGROUND"`, `"MASK_MODE_SEMANTIC"` | -| `dilation` | float | ❌ | Mask dilation (via `ExtraParams["dilation"]`, Imagen only): Range [0, 1] | -| `maskClasses` | int[] | ❌ | Mask classes (via `ExtraParams["maskClasses"]`, Imagen only): For `MASK_MODE_SEMANTIC` | +| Parameter | Type | Required | Notes | +| -------------------- | ------ | -------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model` | string | ✅ | Model identifier (must be Gemini or Imagen model) | +| `prompt` | string | ✅ | Text description of the edit | +| `image[]` | binary | ✅ | Image file(s) to edit (supports multiple images) | +| `mask` | binary | ❌ | Mask image file | +| `type` | string | ❌ | Edit type: `"inpainting"`, `"outpainting"`, `"inpaint_removal"`, `"bgswap"` (Imagen only) | +| `n` | int | ❌ | Number of images to generate (1-10) | +| `output_format` | string | ❌ | Output format: 
`"png"`, `"webp"`, `"jpeg"` | +| `output_compression` | int | ❌ | Compression level (0-100%) | +| `seed` | int | ❌ | Seed for reproducibility (via `ExtraParams["seed"]`) | +| `negative_prompt` | string | ❌ | Negative prompt (via `ExtraParams["negativePrompt"]`) | +| `maskMode` | string | ❌ | Mask mode (via `ExtraParams["maskMode"]`, Imagen only): `"MASK_MODE_USER_PROVIDED"`, `"MASK_MODE_BACKGROUND"`, `"MASK_MODE_FOREGROUND"`, `"MASK_MODE_SEMANTIC"` | +| `dilation` | float | ❌ | Mask dilation (via `ExtraParams["dilation"]`, Imagen only): Range [0, 1] | +| `maskClasses` | int[] | ❌ | Mask classes (via `ExtraParams["maskClasses"]`, Imagen only): For `MASK_MODE_SEMANTIC` | --- @@ -454,6 +843,7 @@ Vertex uses the same conversion functions as Gemini: **Model Validation**: Only Gemini and Imagen models are supported. Other models return `ConfigurationError`. **Request Body Processing**: + - All request bodies are converted to `map[string]interface{}` for Vertex API compatibility - The `region` field is removed before sending to Vertex API - For Gemini models, unsupported fields are stripped via `stripVertexGeminiUnsupportedFields()` (removes `id` from function_call and function_response) @@ -466,6 +856,7 @@ Vertex uses the same conversion functions as Gemini: **Endpoint Selection** The provider automatically selects the endpoint based on model type: + - **Gemini models**: `/v1/projects/{projectID}/locations/{region}/publishers/google/models/{model}:generateContent` - **Imagen models**: `/v1/projects/{projectID}/locations/{region}/publishers/google/models/{model}:predict` @@ -509,7 +900,9 @@ Lists models available in the specified project and region with metadata and dep ## Custom vs Non-Custom Models -**Important**: Vertex AI's List Models API **only returns custom fine-tuned models** that have been deployed to your project. It does NOT return standard foundation models (Gemini, Claude, etc.). 
+ **Important**: Vertex AI's List Models API **only returns custom fine-tuned + models** that have been deployed to your project. It does NOT return standard + foundation models (Gemini, Claude, etc.). To provide a complete model listing experience, Bifrost performs **multi-pass model discovery**: @@ -622,53 +1015,49 @@ Model listing is paginated automatically. If more than 100 models exist, `next_p ## Caveats -**Severity**: High -**Behavior**: Both project_id and region required for all operations -**Impact**: Request fails without valid GCP project/region configuration -**Code**: `vertex.go:127-138` + **Severity**: High **Behavior**: Both project_id and region required for all + operations **Impact**: Request fails without valid GCP project/region + configuration **Code**: `vertex.go:127-138` -**Severity**: Medium -**Behavior**: Tokens cached and automatically refreshed when expired -**Impact**: First request slightly slower due to auth; cached for subsequent requests -**Code**: `vertex.go:34-55` + **Severity**: Medium **Behavior**: Tokens cached and automatically refreshed + when expired **Impact**: First request slightly slower due to auth; cached for + subsequent requests **Code**: `vertex.go:34-55` -**Severity**: Medium -**Behavior**: Automatic detection of Anthropic vs Gemini models -**Impact**: Different conversion logic applied transparently -**Code**: `vertex.go` chat/responses endpoints + **Severity**: Medium **Behavior**: Automatic detection of Anthropic vs Gemini + models **Impact**: Different conversion logic applied transparently **Code**: + `vertex.go` chat/responses endpoints -**Severity**: Low -**Behavior**: Responses API automatically routes to Anthropic or Gemini implementation based on model -**Impact**: Different conversion logic applied transparently per model -**Code**: `vertex.go:836-1080` + **Severity**: Low **Behavior**: Responses API automatically routes to + Anthropic or Gemini implementation based on model **Impact**: Different + 
conversion logic applied transparently per model **Code**: + `vertex.go:836-1080` -**Severity**: Low -**Behavior**: `anthropic_version` always set to `vertex-2023-10-16` for Claude -**Impact**: Cannot override Anthropic version for Claude on Vertex -**Code**: `utils.go:33, 71` + **Severity**: Low **Behavior**: `anthropic_version` always set to + `vertex-2023-10-16` for Claude **Impact**: Cannot override Anthropic version + for Claude on Vertex **Code**: `utils.go:33, 71` -**Severity**: Low -**Behavior**: Vertex returns float64 embeddings, and Bifrost preserves that precision in normalized embedding responses -**Impact**: No precision loss in the `/v1/embeddings` response path -**Code**: `embedding.go:84-91` + **Severity**: Low **Behavior**: Vertex returns float64 embeddings, and Bifrost + preserves that precision in normalized embedding responses **Impact**: No + precision loss in the `/v1/embeddings` response path **Code**: + `embedding.go:84-91` -**Severity**: High -**Behavior**: Vertex AI's List Models API only returns custom fine-tuned models, NOT foundation models -**Impact**: Bifrost performs three-pass discovery to include foundation models from aliases and the key-level `models` allowlist -**Why**: This is a Vertex AI API limitation - foundation models must be explicitly configured -**Code**: `models.go:76-217` + **Severity**: High **Behavior**: Vertex AI's List Models API only returns + custom fine-tuned models, NOT foundation models **Impact**: Bifrost performs + three-pass discovery to include foundation models from aliases and the + key-level `models` allowlist **Why**: This is a Vertex AI API limitation - + foundation models must be explicitly configured **Code**: `models.go:76-217` --- @@ -683,40 +1072,22 @@ Model listing is paginated automatically. 
If more than 100 models exist, `next_p **Note**: For `global` region, endpoint is `https://aiplatform.googleapis.com/v1/projects/{project}/locations/global/{resource}` -## Setup & Configuration - -Vertex AI requires project configuration, region selection, and Google Cloud authentication. For detailed instructions on setting up Vertex AI, see the quickstart guides: - - - - -See **[Provider-Specific Authentication - Google Vertex](../../quickstart/gateway/provider-configuration#google-vertex)** in the Gateway Quickstart for configuration steps using Web UI, API, or config.json. - - - - -See **[Provider-Specific Authentication - Google Vertex](../../quickstart/go-sdk/provider-configuration#google-vertex)** in the Go SDK Quickstart for programmatic configuration examples. - - - - ---- - ## Video Generation Vertex AI routes video generation through Gemini's Veo models using the `predictLongRunning` endpoint. All parameters are identical to [Gemini Video Generation](/providers/supported-providers/gemini#video-generation). -Only Veo models are supported (e.g., `veo-2.0-generate-001`). Passing a non-Veo model name returns a configuration error. + Only Veo models are supported (e.g., `veo-2.0-generate-001`). Passing a + non-Veo model name returns a configuration error. 
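As a rough sketch of the behavior described above (the prefix check is an assumption for illustration; the real validation lives in the Go provider code), the Veo-only gate might look like:

```python
# Hypothetical sketch of the Veo-only model check for Vertex video generation.
# Assumption: Veo models are identified by a simple "veo-" name prefix; the
# actual matching logic in Bifrost's Go provider may differ.
def validate_video_model(model: str) -> None:
    if not model.startswith("veo-"):
        raise ValueError(
            f"vertex video generation supports only Veo models, got {model!r}"
        )

validate_video_model("veo-2.0-generate-001")  # accepted, returns without error
try:
    validate_video_model("gemini-1.5-pro")
except ValueError as e:
    print("rejected:", e)
```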
**Supported Operations** -| Operation | Supported | Notes | -|-----------|-----------|-------| -| Generate | ✅ | `POST /v1/videos` | -| Retrieve | ✅ | `GET /v1/videos/{id}` | -| Download | ✅ | `GET /v1/videos/{id}/content` | -| Delete | ❌ | Not supported | -| List | ❌ | Not supported | -| Remix | ❌ | Not supported | +| Operation | Supported | Notes | +| --------- | --------- | ----------------------------- | +| Generate | ✅ | `POST /v1/videos` | +| Retrieve | ✅ | `GET /v1/videos/{id}` | +| Download | ✅ | `GET /v1/videos/{id}/content` | +| Delete | ❌ | Not supported | +| List | ❌ | Not supported | +| Remix | ❌ | Not supported | diff --git a/docs/quickstart/gateway/provider-configuration.mdx b/docs/quickstart/gateway/provider-configuration.mdx index 597d3b2a1e..db259873f0 100644 --- a/docs/quickstart/gateway/provider-configuration.mdx +++ b/docs/quickstart/gateway/provider-configuration.mdx @@ -84,60 +84,56 @@ curl --location 'http://localhost:8080/api/providers' \ - -Each key in a provider needs to have a unique name. - - +Each key in a provider needs to have a unique name. 
```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key", - "value": "env.OPENAI_API_KEY", - "models": ["*"], - "weight": 1.0 - } - ] - }, - "anthropic": { - "keys": [ - { - "name": "anthropic-key", - "value": "env.ANTHROPIC_API_KEY", - "models": ["*"], - "weight": 1.0 - } - ] - }, - "vllm-local": { - "keys": [ - { - "name": "vllm-key", - "value": "dummy", - "models": ["*"], - "weight": 1.0 - } - ], - "network_config": { - "base_url": "http://vllm-endpoint:8000", - "default_request_timeout_in_seconds": 60 - }, - "custom_provider_config": { - "base_provider_type": "openai", - "allowed_requests": { - "chat_completion": true, - "chat_completion_stream": true - } - } + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key", + "value": "env.OPENAI_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "anthropic": { + "keys": [ + { + "name": "anthropic-key", + "value": "env.ANTHROPIC_API_KEY", + "models": ["*"], + "weight": 1.0 + } + ] + }, + "vllm-local": { + "keys": [ + { + "name": "vllm-key", + "value": "dummy", + "models": ["*"], + "weight": 1.0 + } + ], + "network_config": { + "base_url": "http://vllm-endpoint:8000", + "default_request_timeout_in_seconds": 60 + }, + "custom_provider_config": { + "base_provider_type": "openai", + "allowed_requests": { + "chat_completion": true, + "chat_completion_stream": true } + } } + } } ``` -
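The `env.`-prefixed values in the config above are resolved from the process environment at load time. A minimal sketch of that resolution rule (illustrative only, not Bifrost's actual implementation):

```python
import os

def resolve_key_value(value: str) -> str:
    """Resolve a config key value: 'env.NAME' reads NAME from the
    environment; anything else is treated as a literal API key.
    Sketch of the documented behavior, not the gateway's real code."""
    if value.startswith("env."):
        name = value[len("env."):]
        key = os.environ.get(name)
        if key is None:
            raise KeyError(f"environment variable {name} is not set")
        return key
    return value

os.environ["OPENAI_API_KEY"] = "sk-proj-example"
print(resolve_key_value("env.OPENAI_API_KEY"))  # -> sk-proj-example
print(resolve_key_value("sk-direct-literal"))   # -> sk-direct-literal
```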
@@ -147,7 +143,12 @@ Each key in a provider needs to have a unique name. -**Air-gapped or self-signed certificates:** If your custom provider uses HTTPS with a self-signed or internal CA certificate, add `"insecure_skip_verify": true` or `"ca_cert_pem": "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----"` to `network_config`. See [Custom Providers - TLS](../../providers/custom-providers#tls-for-self-signed-or-internal-certificates) for details. + **Air-gapped or self-signed certificates:** If your custom provider uses HTTPS + with a self-signed or internal CA certificate, add `"insecure_skip_verify": + true` or `"ca_cert_pem": "-----BEGIN CERTIFICATE-----\n...\n-----END + CERTIFICATE-----"` to `network_config`. See [Custom Providers - + TLS](../../providers/custom-providers#tls-for-self-signed-or-internal-certificates) + for details. ## Making Requests @@ -179,6 +180,7 @@ export COHERE_API_KEY="your-cohere-api-key" ``` **Environment Variable Handling:** + - Use `"value": "env.VARIABLE_NAME"` to reference environment variables - Use `"value": "sk-proj-xxxxxxxxx"` to pass keys directly - All sensitive data is automatically redacted in GET requests and UI responses for security @@ -218,7 +220,7 @@ curl --location 'http://localhost:8080/api/providers' \ }, { "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_2", + "value": "env.OPENAI_API_KEY_2", "models": ["*"], "weight": 0.3 } @@ -232,24 +234,24 @@ curl --location 'http://localhost:8080/api/providers' \ ```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key-1", - "value": "env.OPENAI_API_KEY_1", - "models": ["*"], - "weight": 0.7 - }, - { - "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_2", - "models": ["*"], - "weight": 0.3 - } - ] + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY_1", + "models": ["*"], + "weight": 0.7 + }, + { + "name": "openai-key-2", + "value": "env.OPENAI_API_KEY_2", + "models": ["*"], + "weight": 
0.3 } + ] } + } } ``` @@ -304,24 +306,24 @@ curl --location 'http://localhost:8080/api/providers' \ ```json { - "providers": { - "openai": { - "keys": [ - { - "name": "openai-key-1", - "value": "env.OPENAI_API_KEY", - "models": ["gpt-4o", "gpt-4o-mini"], - "weight": 1.0 - }, - { - "name": "openai-key-2", - "value": "env.OPENAI_API_KEY_PREMIUM", - "models": ["o1-preview", "o1-mini"], - "weight": 1.0 - } - ] + "providers": { + "openai": { + "keys": [ + { + "name": "openai-key-1", + "value": "env.OPENAI_API_KEY", + "models": ["gpt-4o", "gpt-4o-mini"], + "weight": 1.0 + }, + { + "name": "openai-key-2", + "value": "env.OPENAI_API_KEY_PREMIUM", + "models": ["o1-preview", "o1-mini"], + "weight": 1.0 } + ] } + } } ``` @@ -337,7 +339,6 @@ Override the default API endpoint for a provider. This is useful for connecting -